# SegPath Histology Image Segmentation

This notebook trains a semantic segmentation model for identifying 4 cell types in H&E-stained histology images:
1. Epithelial cells (panCK_Epithelium)
2. Smooth muscle cells (aSMA_SmoothMuscle)
3. Lymphocytes (CD3CD20_Lymphocyte) 
4. Myeloid cells (MNDA_MyeloidCell)

We'll use the SegPath dataset, which provides paired H&E images and binary masks for each cell type. The model can then be applied to Visium HD spatial transcriptomics data.

## 1. Environment Setup and Imports

In [None]:
# Install required packages
!pip install torch torchvision numpy pandas matplotlib albumentations tqdm segmentation-models-pytorch

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm.notebook import tqdm
import segmentation_models_pytorch as smp
import time
import random

# Reproducibility: every RNG used by the notebook is seeded with the same value.
seed = 42

# cuDNN: disable the auto-tuner and force deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# CPU-side generators (PyTorch, NumPy, Python stdlib).
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# CUDA generators, only touched when a CUDA device is present.
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

## 2. Check Device and Directories

In [None]:
# Pick the compute device: Apple MPS first, then CUDA, then CPU.
# `torch.backends.mps` does not exist on older PyTorch builds, so it is looked
# up defensively instead of being accessed directly (which would raise
# AttributeError there).
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
    print("Using MPS (Metal) device for training")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using {device} for training")

In [None]:
# Root folder holding the SegPath downloads (local default).
base_dir = os.path.expanduser("~/Downloads")

# When running on GCP, modify this path
# base_dir = "/path/to/your/data/on/gcp"

# cell type -> (image folder name, file-info CSV name), both relative to base_dir
_segpath_sources = {
    'epithelial': ('panCK_Epithelium', 'panCK_fileinfo.csv'),
    'smooth_muscle': ('aSMA_SmoothMuscle', 'aSMA_fileinfo.csv'),
    'lymphocyte': ('CD3CD20_Lymphocyte', 'CD3CD20_fileinfo.csv'),
    'myeloid': ('MNDA_MyeloidCell', 'MNDA_fileinfo.csv'),
}

# Per-cell-type image directories
cell_type_dirs = {
    name: os.path.join(base_dir, folder)
    for name, (folder, _) in _segpath_sources.items()
}

# Per-cell-type CSV files describing the images and their split
csv_paths = {
    name: os.path.join(base_dir, csv_name)
    for name, (_, csv_name) in _segpath_sources.items()
}

# Location of the merged CSV written by the next section
combined_csv_path = os.path.join(base_dir, 'combined_dataset.csv')

# Models and results are written here; created eagerly so later cells can save into it
output_dir = os.path.join(base_dir, 'model_output')
os.makedirs(output_dir, exist_ok=True)

# Report which image directories are present
for cell_type, dir_path in cell_type_dirs.items():
    print(
        f"✓ {cell_type} directory found: {dir_path}"
        if os.path.exists(dir_path)
        else f"✗ {cell_type} directory not found: {dir_path}"
    )

# Report which CSV files are present
for cell_type, csv_path in csv_paths.items():
    print(
        f"✓ {cell_type} CSV found: {csv_path}"
        if os.path.exists(csv_path)
        else f"✗ {cell_type} CSV not found: {csv_path}"
    )

## 3. Combine CSV Files from Different Cell Types

In [None]:
# Per-cell-type frames, each tagged with the cell type it came from
dfs = []

for cell_type, csv_path in csv_paths.items():
    # Missing CSVs are reported and skipped rather than aborting the cell
    if not os.path.exists(csv_path):
        print(f"Warning: {csv_path} not found")
        continue
    df = pd.read_csv(csv_path)
    df['cell_type'] = cell_type
    dfs.append(df)

if not dfs:
    print("No valid CSV files found")
else:
    # Stack everything into one frame and persist it for the Dataset class
    combined_df = pd.concat(dfs, ignore_index=True)
    combined_df.to_csv(combined_csv_path, index=False)
    print(f"Combined CSV created with {len(combined_df)} entries")

    # Quick sanity statistics: rows per split and per cell type
    print("\nSplit distribution:")
    print(combined_df['train_val_test'].value_counts())

    print("\nCell type distribution:")
    print(combined_df['cell_type'].value_counts())

## 4. Dataset Implementation

Create a PyTorch Dataset class to load the SegPath data with balancing between cell types.

In [None]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import cv2
import albumentations as A
from albumentations.pytorch import ToTensorV2

# StainNormalizer class for H&E stain normalization
class StainNormalizer:
    """Optional wrapper around ``staintools.StainNormalizer``.

    If ``staintools`` cannot be imported the wrapper degrades to a no-op:
    ``fit`` does nothing and ``transform`` returns its input unchanged.

    Attributes:
        normalizer: The underlying staintools normalizer, or ``None`` when
            staintools is unavailable.
        is_fitted: ``True`` once ``fit`` has succeeded.
        available: ``True`` when staintools was imported successfully.
    """

    def __init__(self, method='macenko'):
        """
        Args:
            method: Normalization method name forwarded to staintools.
        """
        # Set defaults first so every attribute exists on both code paths
        # (previously `normalizer`/`is_fitted` were missing without staintools).
        self.normalizer = None
        self.is_fitted = False
        self.available = False

        try:
            import staintools
        except ImportError:
            print("Warning: staintools not installed. Stain normalization will be disabled.")
            return

        self.normalizer = staintools.StainNormalizer(method=method)
        self.available = True

    def fit(self, reference_img):
        """Fit the normalizer to a reference image.

        Failures are reported on stdout and leave the normalizer unfitted.
        """
        if not self.available:
            return

        try:
            self.normalizer.fit(reference_img)
            self.is_fitted = True
            print("Stain normalizer fitted successfully")
        except Exception as e:
            print(f"Error fitting stain normalizer: {e}")

    def transform(self, image):
        """Apply stain normalization to an image.

        Returns the input image unchanged when the normalizer is unavailable,
        not fitted, or when normalization of this particular image fails.
        """
        if not self.available or not self.is_fitted:
            return image

        try:
            return self.normalizer.transform(image)
        except Exception as e:
            print(f"Stain normalization failed: {e}")
            return image

class SegPathDataset(Dataset):
    """Dataset for loading SegPath data with balanced sampling.

    Each item is a dict with keys:
        'image': transformed H&E image (tensor when the default transform is used)
        'mask': label mask where foreground pixels carry the sample's cell-type
            label (1-4) and everything else is 0; returned as int64 when it is
            a tensor so it can be fed directly to ``nn.CrossEntropyLoss``
        'cell_type': integer label of the cell type this sample was drawn from
        'path': path of the H&E image
    """

    # Native tile size expected on disk and the padded size fed to the network
    # (992 is the next multiple of 32 above 984).
    IMAGE_SIZE = 984
    PADDED_SIZE = 992

    def __init__(self,
                 csv_file,
                 base_dir,
                 split='train',
                 samples_per_cell_type=2000,  # Increased sample size
                 seed=42,
                 transform=None,
                 use_stain_normalization=True):
        """
        Args:
            csv_file: Path to the combined CSV file
            base_dir: Base directory containing cell type folders
            split: 'train', 'val', or 'test'
            samples_per_cell_type: How many samples to take from each cell type
            seed: Random seed for reproducibility
            transform: Optional transform to apply
            use_stain_normalization: Whether to apply stain normalization
        """
        # Read the combined CSV file and keep only the requested split
        self.df = pd.read_csv(csv_file)
        self.base_dir = base_dir
        self.df = self.df[self.df['train_val_test'] == split]

        # Only keep mask files (HE image paths are derived from these).
        # regex=False: match the literal suffix, '.' must not act as a wildcard.
        # na=False: rows without a filename are simply excluded.
        mask_files = self.df[
            self.df['filename'].str.contains('_mask.png', regex=False, na=False)
        ]

        # Map cell types to numerical labels (0 is reserved for background)
        self.cell_type_mapping = {
            'epithelial': 1,
            'smooth_muscle': 2,
            'lymphocyte': 3,
            'myeloid': 4
        }

        # Set random seed for reproducibility (note: this reseeds NumPy globally)
        np.random.seed(seed)

        # Initialize stain normalizer if requested
        self.use_stain_normalization = use_stain_normalization
        self.normalizer = None
        if use_stain_normalization:
            self.normalizer = StainNormalizer(method='macenko')
            self._fit_stain_normalizer(base_dir)

        # Sample files for each cell type
        self.samples = []
        self.corrupted_files = []  # Keep track of corrupted files

        for cell_type, cell_type_label in self.cell_type_mapping.items():
            cell_type_masks = mask_files[mask_files['cell_type'] == cell_type]

            # Never request more rows than are available
            sample_size = min(samples_per_cell_type, len(cell_type_masks))
            sampled_masks = cell_type_masks.sample(n=sample_size, random_state=seed)

            for filename in sampled_masks['filename']:
                mask_path = os.path.join(base_dir, filename)
                he_path = os.path.join(base_dir, filename.replace('_mask.png', '_HE.png'))
                if self._is_valid_pair(he_path, mask_path):
                    self.samples.append((he_path, mask_path, cell_type_label))

        # Print information about samples and corrupted files
        print(f"Loaded {len(self.samples)} samples for {split} split")
        print(f"Found {len(self.corrupted_files)} corrupted files that were skipped")

        # Per-cell-type counts (cell types with no samples are not listed)
        for cell_type_name, cell_type_label in self.cell_type_mapping.items():
            count = sum(1 for _, _, label in self.samples if label == cell_type_label)
            if count:
                print(f"  {cell_type_name}: {count} samples")

        # Set transform
        self.transform = transform if transform is not None else self._get_default_transform(split)

    def _fit_stain_normalizer(self, base_dir):
        """Fit the stain normalizer on the first readable epithelial H&E image.

        The reference is taken from the rows of the current split, so different
        splits may end up with different reference images.
        """
        if not getattr(self.normalizer, 'available', True):
            return  # staintools missing: nothing to fit, skip the image read

        epithelial_rows = self.df[self.df['cell_type'] == 'epithelial']
        for filename in epithelial_rows['filename']:
            ref_img_path = os.path.join(base_dir, filename.replace('_mask.png', '_HE.png'))
            if not os.path.exists(ref_img_path):
                continue
            ref_img = cv2.imread(ref_img_path)
            if ref_img is None:
                continue
            self.normalizer.fit(cv2.cvtColor(ref_img, cv2.COLOR_BGR2RGB))
            return

        print("Warning: no readable epithelial reference image found; "
              "stain normalization will be skipped.")

    def _is_valid_pair(self, he_path, mask_path):
        """Return True when both files exist, decode, and the image has the expected size.

        Pairs that exist but fail to decode or have unexpected dimensions are
        recorded in ``self.corrupted_files``; missing files are skipped silently.
        """
        if not (os.path.exists(he_path) and os.path.exists(mask_path)):
            return False

        try:
            img = cv2.imread(he_path)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

            # imread returns None for unreadable/corrupted files
            if img is None or mask is None:
                self.corrupted_files.append((he_path, mask_path))
                return False

            if img.shape[0] != self.IMAGE_SIZE or img.shape[1] != self.IMAGE_SIZE:
                print(f"Warning: Unexpected image dimensions {img.shape} for {he_path}")
                self.corrupted_files.append((he_path, mask_path))
                return False

            return True
        except Exception as e:
            print(f"Error reading files {he_path} or {mask_path}: {str(e)}")
            self.corrupted_files.append((he_path, mask_path))
            return False

    @staticmethod
    def _label_mask(mask, label):
        """Turn a binary mask into a label mask: foreground -> ``label``, rest -> 0.

        Any non-zero pixel counts as foreground, so both 0/1 and 0/255 encoded
        masks are handled.
        """
        labeled = np.zeros_like(mask)
        labeled[mask > 0] = label
        return labeled

    def _get_default_transform(self, phase):
        """Get default transforms if none are provided"""
        # Add padding to make dimensions divisible by 32
        pad = A.PadIfNeeded(min_height=self.PADDED_SIZE, min_width=self.PADDED_SIZE,
                            border_mode=cv2.BORDER_CONSTANT, value=0)

        # Normalization and conversion to tensor (shared by all phases)
        to_model_input = [
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ]

        if phase != 'train':  # validation or test: no augmentation
            return A.Compose([pad, *to_model_input])

        return A.Compose([
            pad,

            # Spatial transforms - stronger and more frequent
            A.RandomRotate90(p=0.7),
            A.HorizontalFlip(p=0.7),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(
                shift_limit=0.1,
                scale_limit=0.2,  # Increased scale variation
                rotate_limit=30,  # Increased rotation range
                p=0.7            # Higher probability of applying
            ),

            # Color augmentation - critical for H&E images
            A.OneOf([
                A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=25, p=0.7),
                A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
                A.CLAHE(clip_limit=4.0, p=0.7),  # Contrast Limited Adaptive Histogram Equalization
            ], p=0.8),

            # Add some noise/blur to simulate microscopy artifacts
            A.OneOf([
                A.GaussianBlur(blur_limit=3, p=0.5),
                A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
            ], p=0.5),

            *to_model_input,
        ])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, label and transform one sample.

        On any failure the next index is tried, up to ``max_retries`` times;
        after that an all-zero dummy sample is returned so that a training loop
        is not interrupted by a few unreadable files.
        """
        max_retries = 5
        for attempt in range(max_retries):
            try:
                img_path, mask_path, label = self.samples[idx]

                image = cv2.imread(img_path)
                if image is None:
                    print(f"Warning: Failed to load image {img_path} - will try another sample")
                    idx = (idx + 1) % len(self.samples)
                    continue

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                # Apply stain normalization if enabled
                if self.use_stain_normalization and self.normalizer is not None:
                    image = self.normalizer.transform(image)

                mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
                if mask is None:
                    print(f"Warning: Failed to load mask {mask_path} - will try another sample")
                    idx = (idx + 1) % len(self.samples)
                    continue

                # Convert binary mask to the specific cell type label
                labeled_mask = self._label_mask(mask, label)

                # Apply transformations
                transformed = self.transform(image=image, mask=labeled_mask)
                image = transformed['image']
                labeled_mask = transformed['mask']

                # The mask tensor inherits uint8 from the PNG; class-index
                # targets for cross-entropy must be int64 (this also matches
                # the dtype of the dummy mask below, so batches collate cleanly).
                if isinstance(labeled_mask, torch.Tensor):
                    labeled_mask = labeled_mask.long()

                return {
                    'image': image,
                    'mask': labeled_mask,
                    'cell_type': label,
                    'path': img_path
                }
            except Exception as e:
                print(f"Error in __getitem__ for {idx}: {str(e)} - will try another sample")
                idx = (idx + 1) % len(self.samples)

        # If we've exhausted all retries, create a dummy sample
        # This is a last resort to avoid crashing the training loop
        print(f"WARNING: Failed to load any valid sample after {max_retries} attempts")
        dummy_image = torch.zeros(3, self.PADDED_SIZE, self.PADDED_SIZE)  # Padded dimensions
        dummy_mask = torch.zeros(self.PADDED_SIZE, self.PADDED_SIZE).long()

        return {
            'image': dummy_image,
            'mask': dummy_mask,
            'cell_type': 1,  # Default to epithelial
            'path': 'dummy_path'
        }

## 5. Visualize Some Sample Data

In [None]:
# One balanced dataset per split; the cap is the number of samples drawn per cell type.
train_dataset = SegPathDataset(
    csv_file=combined_csv_path, base_dir=base_dir,
    split='train', samples_per_cell_type=2000,  # Increased from 500 to 2000
)
val_dataset = SegPathDataset(
    csv_file=combined_csv_path, base_dir=base_dir,
    split='val', samples_per_cell_type=300,  # Increased from 100 to 300
)
test_dataset = SegPathDataset(
    csv_file=combined_csv_path, base_dir=base_dir,
    split='test', samples_per_cell_type=300,  # Increased from 100 to 300
)

# Loaders: only the training loader shuffles; loading happens in the main process.
batch_size = 4  # Smaller batch size for MacBook Pro
_loader_options = {'batch_size': batch_size, 'num_workers': 0}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_batch(loader, title):
    """
    Show the first batch drawn from a DataLoader as rows of (H&E image, mask).

    Args:
        loader: DataLoader yielding dicts with 'image', 'mask' and 'cell_type'.
        title: Figure-level title.
    """
    first_batch = next(iter(loader))
    images = first_batch['image'].cpu().numpy()        # [B, C, H, W]
    masks = first_batch['mask'].cpu().numpy()          # [B, H, W]
    cell_types = first_batch['cell_type'].cpu().numpy()  # [B]

    # Numeric label -> readable name
    cell_type_names = {
        1: 'epithelial',
        2: 'smooth_muscle',
        3: 'lymphocyte',
        4: 'myeloid'
    }

    # Statistics used to undo the input normalization for display
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    n_rows = images.shape[0]
    # squeeze=False keeps `axes` two-dimensional even for a single-sample batch
    _, axes = plt.subplots(n_rows, 2, figsize=(10, 5 * n_rows), squeeze=False)

    for row in range(n_rows):
        image_ax, mask_ax = axes[row][0], axes[row][1]

        # CHW -> HWC, denormalize and clamp into the displayable [0, 1] range
        rgb = np.clip(std * images[row].transpose(1, 2, 0) + mean, 0, 1)

        image_ax.imshow(rgb)
        image_ax.set_title(f"H&E - {cell_type_names.get(cell_types[row], 'Unknown')}")
        image_ax.axis("off")

        # Fixed colour range so each class keeps the same colour across plots
        mask_ax.imshow(masks[row], cmap='viridis', vmin=0, vmax=4)
        mask_ax.set_title("Mask")
        mask_ax.axis("off")

    plt.suptitle(title, fontsize=16)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()


In [None]:
# Visualize one batch from the training loader.
# The loader shuffles, so the batch shown differs between runs of this cell.
visualize_batch(train_loader, "Training Samples")

In [None]:
# Visualize the first batch from the validation loader (created with shuffle=False).
visualize_batch(val_loader, "Validation Samples")

In [None]:
import os
import shutil
import zipfile
from tqdm import tqdm

def export_sampled_images(dataset, output_root, split_name, create_zip=True):
    """
    Copy every (H&E, mask) pair listed in ``dataset.samples`` into
        output_root / split_name / {cell_type_name}/
    and optionally archive that split folder.

    Args:
        dataset: A SegPathDataset instance (already constructed)
        output_root: Base directory for exporting the images
        split_name: 'train', 'val', or 'test' (used to create subfolders)
        create_zip: Whether to create ``segpath_{split_name}.zip`` inside output_root
    """
    print(f"\nExporting {len(dataset.samples)} {split_name} samples to {output_root} ...")

    # Destination folder for this split
    split_folder = os.path.join(output_root, split_name)
    os.makedirs(split_folder, exist_ok=True)

    for he_path, mask_path, label in tqdm(dataset.samples):
        # Numeric label -> cell type name (first matching key)
        matching_names = [
            name for name, value in dataset.cell_type_mapping.items() if value == label
        ]
        cell_type_folder = os.path.join(split_folder, matching_names[0])
        os.makedirs(cell_type_folder, exist_ok=True)

        # Copy image and mask, keeping their original file names and metadata
        for src_path in (he_path, mask_path):
            dest_path = os.path.join(cell_type_folder, os.path.basename(src_path))
            shutil.copy2(src_path, dest_path)

    print(f"Export complete for {split_name} set!")

    if not create_zip:
        return

    # Archive the split folder; entry names are relative to output_root
    print(f"Creating zip file for {split_name} set...")
    zip_path = os.path.join(output_root, f"segpath_{split_name}.zip")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        for folder, _, filenames in os.walk(split_folder):
            for filename in filenames:
                full_path = os.path.join(folder, filename)
                archive.write(full_path, os.path.relpath(full_path, output_root))
    print(f"Zip file created at: {zip_path}")

# Top-level export folder (must be writable)
sampled_dir = os.path.join(base_dir, "sampled")
os.makedirs(sampled_dir, exist_ok=True)

# Export each split in turn: train, then val, then test
for _split_name, _split_dataset in (
    ("train", train_dataset),
    ("val", val_dataset),
    ("test", test_dataset),
):
    export_sampled_images(_split_dataset, sampled_dir, _split_name)

# Optional: one archive containing all splits; entry names start at the
# 'sampled' folder itself, and the per-split zip files are left out.
print("\nCreating combined zip file...")
combined_zip_path = os.path.join(base_dir, "segpath_all_splits.zip")
_archive_root = os.path.dirname(sampled_dir)
with zipfile.ZipFile(combined_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for root, _, files in os.walk(sampled_dir):
        for file in files:
            if file.endswith('.zip'):  # Skip the individual zip files
                continue
            file_path = os.path.join(root, file)
            arcname = os.path.relpath(file_path, _archive_root)
            zipf.write(file_path, arcname)
print(f"Combined zip file created at: {combined_zip_path}")

## 6. Model Definition

We'll use a U-Net model from segmentation_models_pytorch with an ImageNet-pretrained ResNet50 encoder and scSE attention in the decoder.

In [None]:
# Segmentation classes predicted per pixel:
# 0: background, 1: epithelial, 2: smooth muscle, 3: lymphocyte, 4: myeloid
num_classes = 5

# U-Net built and moved to the selected device in one expression
model = smp.Unet(
    encoder_name="resnet50",        # ResNet50 encoder (upgraded from resnet34)
    encoder_weights="imagenet",     # start from ImageNet pre-trained weights
    in_channels=3,                  # RGB input
    classes=num_classes,            # background + 4 cell types
    decoder_attention_type="scse",  # spatial and channel squeeze & excitation in the decoder
).to(device)

print(model)

## 7. Helper Functions for Training and Evaluation

In [None]:
def calculate_iou(pred, target, n_classes=5, ignore_absent=True):
    """Calculate mean IoU (Intersection over Union) over the cell type classes.

    The background class (0) is always skipped.

    Args:
        pred: Array-like of predicted class indices.
        target: Array-like of ground-truth class indices, same shape as ``pred``.
        n_classes: Total number of classes including background.
        ignore_absent: When True (default), classes that occur in neither
            ``pred`` nor ``target`` are left out of the mean. When False they
            count as a perfect 1.0 (the previous behaviour), which inflates the
            score: an all-background prediction would already reach 0.75 on an
            image containing a single cell type.

    Returns:
        Mean IoU over the evaluated classes; 1.0 if no class was evaluated
        (nothing to predict and nothing predicted).
    """
    pred = np.asarray(pred)
    target = np.asarray(target)

    ious = []
    for cls in range(1, n_classes):
        pred_binary = pred == cls
        target_binary = target == cls

        union = np.logical_or(pred_binary, target_binary).sum()
        if union == 0:
            # Class absent from both prediction and ground truth
            if not ignore_absent:
                ious.append(1.0)
            continue

        intersection = np.logical_and(pred_binary, target_binary).sum()
        ious.append(intersection / union)

    if not ious:
        return np.float64(1.0)
    return np.mean(ious)
    

In [None]:
def visualize_predictions(model, test_loader, device, num_samples=4):
    """Plot H&E image, ground-truth mask and predicted mask for the first test batch.

    Args:
        model: Segmentation network producing per-class logits.
        test_loader: DataLoader yielding dicts with 'image', 'mask', 'cell_type'.
        device: Device on which inference is run.
        num_samples: Maximum number of samples from the batch to display.
    """
    model.eval()

    # Only the first batch of the loader is used
    batch = next(iter(test_loader))
    images = batch['image'].to(device)
    masks = batch['mask'].cpu().numpy()

    # Inference without gradient tracking; per-pixel class = argmax over channels
    with torch.no_grad():
        preds = model(images).argmax(dim=1).cpu().numpy()

    images = images.cpu().numpy()

    # Class label -> display colour
    colors = {
        0: [0, 0, 0],       # Background: black
        1: [255, 0, 0],     # Epithelial: red
        2: [0, 255, 0],     # Smooth muscle: green
        3: [0, 0, 255],     # Lymphocyte: blue
        4: [255, 255, 0]    # Myeloid: yellow
    }

    # Class label -> display name
    class_names = {
        0: 'Background',
        1: 'Epithelial',
        2: 'Smooth muscle',
        3: 'Lymphocyte',
        4: 'Myeloid'
    }

    def mask_to_rgb(mask):
        """Paint a label mask using `colors`; unknown labels stay black."""
        rgb = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
        for class_id, color in colors.items():
            rgb[mask == class_id] = color
        return rgb

    # Statistics used to undo the input normalization for display
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    n = min(num_samples, len(images))
    plt.figure(figsize=(15, 5 * n))

    for i in range(n):
        true_mask = masks[i]
        pred_mask = preds[i]

        # CHW -> HWC, denormalize and clamp into [0, 1]
        img = np.clip(std * images[i].transpose(1, 2, 0) + mean, 0, 1)

        true_mask_rgb = mask_to_rgb(true_mask)
        pred_mask_rgb = mask_to_rgb(pred_mask)

        iou = calculate_iou(pred_mask, true_mask)

        # Cell type the sample was drawn from
        cell_type_name = class_names[batch['cell_type'][i].item()]

        # One row per sample: image | ground truth | prediction
        panels = (
            (img, f'H&E Image - {cell_type_name}'),
            (true_mask_rgb, 'True Mask'),
            (pred_mask_rgb, f'Predicted Mask (IoU: {iou:.2f})'),
        )
        for column, (panel, panel_title) in enumerate(panels, start=1):
            plt.subplot(n, 3, i * 3 + column)
            plt.imshow(panel)
            plt.title(panel_title)
            plt.axis('off')

    plt.tight_layout()
    plt.show()

## 8. Plot Training History Function

In [None]:
def plot_training_history(history):
    """Plot loss curves (train vs. validation) and validation IoU per epoch.

    Args:
        history: Dict with equally long lists under 'train_loss', 'val_loss'
            and 'val_iou', one entry per epoch.
    """
    epochs = range(1, len(history['train_loss']) + 1)

    _, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: training and validation loss
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Training Loss')
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Right panel: validation IoU
    iou_ax.plot(epochs, history['val_iou'], 'g-', label='Validation IoU')
    iou_ax.set_title('Validation IoU')
    iou_ax.set_xlabel('Epochs')
    iou_ax.set_ylabel('IoU')
    iou_ax.legend()

    plt.tight_layout()
    plt.show()

## 9. Model Training Function

In [None]:
from segmentation_models_pytorch.losses import DiceLoss

class CombinedLoss(nn.Module):
    """Weighted sum of cross-entropy and multiclass Dice loss.

    Args:
        weight: Optional per-class weight tensor forwarded to CrossEntropyLoss.
        ce_weight: Multiplier for the cross-entropy term.
        dice_weight: Multiplier for the Dice term.
    """

    def __init__(self, weight=None, ce_weight=0.7, dice_weight=0.3):
        super(CombinedLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss(weight=weight)
        self.dice = DiceLoss(mode='multiclass')
        self.ce_weight = ce_weight
        self.dice_weight = dice_weight
        self.num_classes = 5  # background + 4 cell types

    def forward(self, outputs, targets):
        """Compute the combined loss.

        Args:
            outputs: Logits of shape (N, C, H, W).
            targets: Class-index mask of shape (N, H, W); any integer dtype.

        Returns:
            Scalar tensor ``ce_weight * CE + dice_weight * Dice``.
        """
        # Both losses take class indices as int64. In 'multiclass' mode the
        # Dice loss one-hot encodes the targets itself, so no manual one-hot
        # conversion is needed (the previous one relied on an un-imported `F`
        # and produced a float tensor the Dice loss cannot consume).
        targets_long = targets.long()

        ce_loss = self.ce(outputs, targets_long)
        dice_loss = self.dice(outputs, targets_long)

        return self.ce_weight * ce_loss + self.dice_weight * dice_loss

def train_model(model, train_loader, val_loader, device, num_epochs=10, learning_rate=1e-4, class_weights=None):
    """Train the segmentation model with improved learning strategies"""
    # Calculate class weights based on dataset distribution
    class_samples = {
        'background': 1000000,  # Estimate for background pixels
        'epithelial': 53018,
        'smooth_muscle': 62356,
        'lymphocyte': 24546,
        'myeloid': 28270
    }
    
    # Calculate weights (inversely proportional to frequency)
    if class_weights is None:
        num_classes = 5  # background + 4 cell types
        class_weights = [1.0] * num_classes
        for i, class_name in enumerate(['background', 'epithelial', 'smooth_muscle', 'lymphocyte', 'myeloid']):
            if i > 0:  # Skip background or assign small weight
                class_weights[i] = max(class_samples.values()) / class_samples[class_name]
        
        # Normalize weights
        total = sum(class_weights)
        class_weights = [w/total*len(class_weights) for w in class_weights]
    
    # Convert to tensor and move to device
    class_weights_tensor = torch.tensor(class_weights).float().to(device)
    print("Class weights:", class_weights)
    
    # Use standard CrossEntropyLoss with class weights
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
    
    # Optimizer with weight decay for regularization
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3, verbose=True
    )
    
    # Initialize training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_iou': [],
        'lr': []
    }
    
    # Track best model
    best_val_loss = float('inf')
    best_model_weights = None
    
    # Early stopping parameters with increased patience
    patience = 10  # Increased from 5
    no_improve_epochs = 0
    
    # Start training
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        # Training phase
        model.train()
        train_loss = 0.0
        
        # Progress bar for training
        train_pbar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}')
        
        for batch in train_pbar:
            # Move batch to device
            images = batch['image'].to(device)
            masks = batch['mask'].to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)
            
            # Backward pass
            loss.backward()
            
            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            # Update weights
            optimizer.step()
            
            # Update statistics
            train_loss += loss.item() * images.size(0)
            train_pbar.set_postfix({'loss': loss.item()})
        
        # Calculate epoch loss
        epoch_train_loss = train_loss / len(train_loader.dataset)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_iou_scores = []
        class_ious = {i: [] for i in range(1, 5)}  # Track IoU for each class separately
        
        # Progress bar for validation
        val_pbar = tqdm(val_loader, desc=f'Validation Epoch {epoch+1}')
        
        with torch.no_grad():
            for batch in val_pbar:
                # Move batch to device
                images = batch['image'].to(device)
                masks = batch['mask'].to(device)
                
                # Forward pass
                outputs = model(images)
                loss = criterion(outputs, masks)
                
                # Update statistics
                val_loss += loss.item() * images.size(0)
                
                # Calculate IoU for each image in batch
                preds = torch.argmax(outputs, dim=1)
                for i in range(preds.size(0)):
                    # Calculate mean IoU across all classes
                    iou = calculate_iou(preds[i].cpu().numpy(), masks[i].cpu().numpy())
                    val_iou_scores.append(iou)
                    
                    # Calculate per-class IoU
                    for cls in range(1, 5):  # Skip background
                        # Create binary masks for this class
                        pred_binary = (preds[i].cpu().numpy() == cls).astype(np.uint8)
                        target_binary = (masks[i].cpu().numpy() == cls).astype(np.uint8)
                        
                        # Calculate intersection and union
                        intersection = np.logical_and(pred_binary, target_binary).sum()
                        union = np.logical_or(pred_binary, target_binary).sum()
                        
                        # Calculate IoU for this class
                        if union > 0:
                            cls_iou = intersection / union
                        else:
                            # If this class doesn't appear in the ground truth or prediction
                            cls_iou = 1.0 if intersection == 0 else 0.0
                        
                        class_ious[cls].append(cls_iou)
                
                val_pbar.set_postfix({'loss': loss.item()})
        
        # Calculate epoch metrics
        epoch_val_loss = val_loss / len(val_loader.dataset)
        epoch_val_iou = np.mean(val_iou_scores)
        
        # Update learning rate
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step(epoch_val_loss)
        
        # Save history
        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['val_iou'].append(epoch_val_iou)
        history['lr'].append(current_lr)
        
        # Print epoch statistics
        print(f'Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val IoU: {epoch_val_iou:.4f} | LR: {current_lr:.2e}')
        
        # Print class-specific IoUs
        for cls, ious in class_ious.items():
            if ious:  # Check if we have any valid IoUs for this class
                cls_name = ['', 'Epithelial', 'Smooth Muscle', 'Lymphocyte', 'Myeloid'][cls]
                cls_iou = np.mean(ious)
                print(f'  {cls_name} IoU: {cls_iou:.4f}')
        
        # Save best model and check for early stopping
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            best_model_weights = model.state_dict().copy()
            no_improve_epochs = 0
            print(f'New best model saved! (Val Loss: {best_val_loss:.4f})')
        else:
            no_improve_epochs += 1
            
        # Early stopping
        if no_improve_epochs >= patience:
            print(f'Early stopping after {epoch+1} epochs!')
            break
    
    # Load best model weights
    print(f'Training complete. Best Val Loss: {best_val_loss:.4f}')
    if best_model_weights:
        model.load_state_dict(best_model_weights)
    
    return model, history

# Helper function for per-class IoU
def calculate_class_iou(pred, target, cls):
    """Return the intersection-over-union of a single class label.

    Args:
        pred: NumPy array of predicted class labels.
        target: NumPy array of ground-truth class labels.
        cls: The class label to score.

    Returns:
        ``intersection / union`` of the positions equal to ``cls`` in each
        array, or ``1.0`` when the class occurs in neither array.
    """
    # 0/1 membership masks, one per input
    in_pred = (pred == cls).astype(np.uint8)
    in_target = (target == cls).astype(np.uint8)

    overlap = np.logical_and(in_pred, in_target).sum()
    combined = np.logical_or(in_pred, in_target).sum()

    # Class absent from both maps: nothing to divide by
    if combined <= 0:
        return 1.0 if overlap == 0 else 0.0

    return overlap / combined

## 10. Run the Training Process

In [None]:
# Run the training process
def train_segmentation_model(model, train_loader, val_loader, test_loader, device, 
                           num_epochs=10, learning_rate=1e-4, output_dir='model_output'):
    """Train the model, save its weights, then plot the history and test predictions.

    Args:
        model: Network handed to ``train_model``.
        train_loader: Loader passed to ``train_model`` as ``train_loader``.
        val_loader: Loader passed to ``train_model`` as ``val_loader``.
        test_loader: Loader passed to ``visualize_predictions``.
        device: Device forwarded to training and visualization.
        num_epochs: Forwarded to ``train_model``.
        learning_rate: Forwarded to ``train_model``.
        output_dir: Directory (created if missing) that receives
            ``segpath_model.pth``.

    Returns:
        The model object returned by ``train_model``.
    """
    print("Starting training...")

    # train_model hands back the model together with its history object
    fitted_model, training_history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        num_epochs=num_epochs,
        learning_rate=learning_rate
    )

    # Persist the weights under output_dir
    os.makedirs(output_dir, exist_ok=True)
    weights_file = os.path.join(output_dir, 'segpath_model.pth')
    torch.save(fitted_model.state_dict(), weights_file)
    print(f"Model saved to {weights_file}")

    # Training curves
    plot_training_history(training_history)

    # Example predictions from the test loader
    print("Visualizing predictions on test set...")
    visualize_predictions(fitted_model, test_loader, device)

    return fitted_model

# Launch training; the returned model is reused by the inference cells below
trained_model = train_segmentation_model(
    model,
    train_loader,
    val_loader,
    test_loader,
    device,
    num_epochs=10,  # Start with 10 epochs
    learning_rate=1e-4,
    output_dir=output_dir,
)

## 11. Apply Trained Model to New Images

In [None]:
def apply_to_new_image(model, image_path, device):
    """Segment one image file with the trained model and display the result.

    The image is zero-padded on the bottom/right so that both sides are
    multiples of 32, normalized, run through the model, and the arg-max label
    map is cropped back to the original size. The model is left in eval mode.

    Args:
        model: Segmentation network; called as ``model(tensor)`` and expected
            to return per-class scores with the class axis at dim 1.
        image_path: Path to an image readable by ``cv2.imread``.
        device: Torch device the input tensor is moved to.

    Returns:
        2-D NumPy array of predicted class labels with the same height and
        width as the input image.

    Raises:
        ValueError: If the image cannot be read.
        RuntimeError: Re-raised from the forward pass (e.g. out of memory).
    """
    # Load the image
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Failed to load image at {image_path}")

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Get original dimensions
    h, w, _ = image.shape
    print(f"Original image size: {h}x{w}")

    # Round both sides up to the next multiple of 32
    new_h = ((h + 31) // 32) * 32
    new_w = ((w + 31) // 32) * 32

    # Pad on the bottom/right only. A.PadIfNeeded centres the image by default,
    # which shifts the content and makes the [:h, :w] crop below misaligned.
    padded = cv2.copyMakeBorder(
        image, 0, new_h - h, 0, new_w - w, cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )

    transform = A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    # Transform the image and add the batch dimension
    transformed = transform(image=padded)
    image_tensor = transformed['image'].unsqueeze(0).to(device)

    print(f"Padded tensor size: {image_tensor.shape}")

    # Make prediction
    model.eval()
    with torch.no_grad():
        try:
            output = model(image_tensor)
            pred = torch.argmax(output, dim=1).cpu().numpy()[0]
            # Crop back to original size (content sits in the top-left corner)
            pred = pred[:h, :w]
        except RuntimeError as e:
            # Add a hint for out-of-memory failures; every RuntimeError is re-raised
            if "CUDA out of memory" in str(e) or "MPS out of memory" in str(e):
                print("Image too large for direct processing, use apply_to_large_image instead")
            raise

    # Map class labels to colors
    colors = {
        0: [0, 0, 0],       # Background: black
        1: [255, 0, 0],     # Epithelial: red
        2: [0, 255, 0],     # Smooth muscle: green
        3: [0, 0, 255],     # Lymphocyte: blue
        4: [255, 255, 0]    # Myeloid: yellow
    }

    # Create RGB mask
    pred_rgb = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
    for label, color in colors.items():
        pred_rgb[pred == label] = color

    # Display the results
    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(pred_rgb)
    plt.title('Predicted Segmentation')
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    return pred

In [None]:
def test_model_on_random_samples(model, base_dir, device, num_samples=5):
    """Spot-check the model on randomly chosen SegPath image/mask pairs.

    Scans the four cell-type folders under ``base_dir`` for ``*_HE.png`` files
    that have a matching ``*_mask.png``, draws up to ``num_samples`` pairs with
    a fixed seed, and for each pair prints the IoU between the mask and the
    predicted pixels of that folder's class, then shows image, mask and
    prediction side by side. Failures on individual samples are printed and
    the loop continues.

    Args:
        model: Trained network, forwarded to ``apply_to_new_image``.
        base_dir: Directory containing the cell-type folders.
        device: Torch device, forwarded to ``apply_to_new_image``.
        num_samples: Maximum number of pairs to evaluate.

    Returns:
        None. Results are printed and plotted.
    """
    # Define the cell type folders
    cell_type_dirs = {
        'epithelial': os.path.join(base_dir, 'panCK_Epithelium'),
        'smooth_muscle': os.path.join(base_dir, 'aSMA_SmoothMuscle'),
        'lymphocyte': os.path.join(base_dir, 'CD3CD20_Lymphocyte'),
        'myeloid': os.path.join(base_dir, 'MNDA_MyeloidCell')
    }

    # Class label the model predicts for each folder's cell type
    cell_type_labels = {
        'epithelial': 1,
        'smooth_muscle': 2,
        'lymphocyte': 3,
        'myeloid': 4
    }

    # Get all H&E image paths
    he_images = []
    for cell_type, directory in cell_type_dirs.items():
        if not os.path.exists(directory):
            continue

        # os.listdir order is arbitrary; sort so the seeded draw below selects
        # the same files on every run and machine
        for filename in sorted(os.listdir(directory)):
            if not filename.endswith('_HE.png'):
                continue

            he_path = os.path.join(directory, filename)
            mask_path = os.path.join(directory, filename.replace('_HE.png', '_mask.png'))

            if os.path.exists(mask_path):
                # Store the image path, mask path, and cell type
                he_images.append((he_path, mask_path, cell_type))

    print(f"Found {len(he_images)} image pairs across all cell types")

    # Draw with a private generator so the notebook's global random state
    # (seeded once at the top of the file) is not reset as a side effect
    rng = random.Random(42)
    if len(he_images) < num_samples:
        selected_samples = he_images
    else:
        selected_samples = rng.sample(he_images, num_samples)

    # Process each selected sample
    for idx, (img_path, mask_path, cell_type) in enumerate(selected_samples):
        print(f"\nSample {idx+1}/{len(selected_samples)}: {cell_type}")
        print(f"Image: {os.path.basename(img_path)}")

        try:
            # Load the ground truth mask
            ground_truth = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if ground_truth is None:
                raise ValueError(f"Failed to load mask at {mask_path}")

            # Apply the model to get prediction
            prediction = apply_to_new_image(model, img_path, device)

            # Load the original image for visualization
            original = cv2.imread(img_path)
            original = cv2.cvtColor(original, cv2.COLOR_BGR2RGB)

            cell_type_label = cell_type_labels[cell_type]

            # Prepare binary masks for IoU calculation
            gt_binary = (ground_truth == 1).astype(np.uint8)  # In ground truth, 1 indicates the target cell
            pred_binary = (prediction == cell_type_label).astype(np.uint8)  # In prediction, we look for the specific cell type

            # Calculate IoU
            # NOTE(review): an empty union scores 0.0 here, whereas
            # calculate_class_iou scores it 1.0 - confirm which convention is wanted
            intersection = np.logical_and(gt_binary, pred_binary).sum()
            union = np.logical_or(gt_binary, pred_binary).sum()
            if union > 0:
                iou = intersection / union
            else:
                iou = 0.0

            print(f"IoU score: {iou:.4f}")

            # Visualize original, ground truth, and prediction
            plt.figure(figsize=(15, 5))

            plt.subplot(1, 3, 1)
            plt.imshow(original)
            plt.title('Original H&E')
            plt.axis('off')

            plt.subplot(1, 3, 2)
            plt.imshow(gt_binary, cmap='gray')
            plt.title(f'Ground Truth ({cell_type})')
            plt.axis('off')

            plt.subplot(1, 3, 3)
            # Colour code: black = background, red = this sample's cell type,
            # green = any other predicted cell type
            colors = np.zeros((prediction.shape[0], prediction.shape[1], 3), dtype=np.uint8)
            for i in range(5):  # 5 classes including background
                if i == 0:  # Background
                    colors[prediction == i] = [0, 0, 0]  # Black
                elif i == cell_type_label:  # Target cell type
                    colors[prediction == i] = [255, 0, 0]  # Red
                else:  # Other cell types
                    colors[prediction == i] = [0, 255, 0]  # Green

            plt.imshow(colors)
            plt.title(f'Prediction (IoU: {iou:.2f})')
            plt.axis('off')

            plt.tight_layout()
            plt.show()

        except Exception as e:
            # Best-effort: report the failing sample and move on to the next one
            print(f"Error processing {img_path}: {str(e)}")

    return

# Spot-check the freshly trained model on three random SegPath samples
test_model_on_random_samples(
    model=trained_model,
    base_dir=base_dir,
    device=device,
    num_samples=3,
)

## 12. Apply to Visium HD Images

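This section applies the trained model to Visium HD bladder tissue images and to whole-slide H&E scans.
The cells below compare segmentation results at several image scales on each TMA folder's `tissue_hires_image.png` and analyze a center patch from each `.svs` slide; tiling for large images and mapping segmentation results to Visium spots are not implemented in this section.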

In [None]:
def test_tma_with_different_scales(model, image_path, device, scales=[1.0, 0.5, 0.25]):
    """Run the model on a center crop of an image at several scales and compare them.

    For every scale the crop is resized, zero-padded on the bottom/right to a
    multiple of 32, normalized and segmented. Probability maps, the label map,
    an overlay and a pie chart are displayed, and the scale with the highest
    share of non-background pixels is reported as the recommendation. The model
    is left in eval mode.

    Args:
        model: Segmentation network; expected to return per-class scores (class
            axis at dim 1) for the five classes listed in ``class_names``.
        image_path: Path to an image readable by ``cv2.imread``.
        device: Torch device the input tensor is moved to.
        scales: Resize factors to try. The list is only read, never modified.

    Returns:
        ``(results, best_scale)``. ``results`` maps each scale to a dict with
        the keys ``'class_counts'``, ``'prediction'`` and
        ``'non_bg_percentage'``; ``best_scale`` is the scale with the most
        non-background pixels, or ``None`` if no scale produced any. When the
        image cannot be loaded the result is ``({}, None)`` so that callers can
        always unpack two values.
    """
    from matplotlib.patches import Patch

    # Load the image
    image = cv2.imread(image_path)
    if image is None:
        print(f"Could not load image: {image_path}")
        return {}, None

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Take a center crop to speed up processing
    h, w = image.shape[:2]
    center_y, center_x = h // 2, w // 2
    crop_size = 2000  # Larger crop to allow for scaling
    crop = image[
        max(0, center_y - crop_size//2):min(h, center_y + crop_size//2),
        max(0, center_x - crop_size//2):min(w, center_x + crop_size//2)
    ]

    # Display the original crop
    plt.figure(figsize=(10, 10))
    plt.imshow(crop)
    plt.title(f"Original Image Crop from {os.path.basename(image_path)}")
    plt.axis('off')
    plt.show()

    class_names = ['Background', 'Epithelial', 'Smooth Muscle', 'Lymphocyte', 'Myeloid']
    colors = {
        0: [0, 0, 0],       # Background: black
        1: [255, 0, 0],     # Epithelial: red
        2: [0, 255, 0],     # Smooth muscle: green
        3: [0, 0, 255],     # Lymphocyte: blue
        4: [255, 255, 0]    # Myeloid: yellow
    }

    # Normalization + tensor conversion is the same for every scale
    to_tensor = A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

    results = {}
    best_scale = None
    best_non_bg_percentage = 0

    for scale in scales:
        print(f"\nTesting scale: {scale:.2f}")

        # Resize the image according to scale
        if scale != 1.0:
            scaled_image = cv2.resize(crop, None, fx=scale, fy=scale)
        else:
            scaled_image = crop.copy()

        # Display scaled image
        plt.figure(figsize=(8, 8))
        plt.imshow(scaled_image)
        plt.title(f"Scale: {scale:.2f}")
        plt.axis('off')
        plt.show()

        # Round both sides up to the next multiple of 32
        scaled_h, scaled_w = scaled_image.shape[:2]
        padded_h = ((scaled_h + 31) // 32) * 32
        padded_w = ((scaled_w + 31) // 32) * 32

        # Pad on the bottom/right only. A.PadIfNeeded centres the image by
        # default, which would misalign the top-left crops taken below.
        padded = cv2.copyMakeBorder(
            scaled_image, 0, padded_h - scaled_h, 0, padded_w - scaled_w,
            cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )

        image_tensor = to_tensor(image=padded)['image'].unsqueeze(0).to(device)

        # Get model output
        model.eval()
        with torch.no_grad():
            output = model(image_tensor)

        # Per-class probabilities, cropped back to the unpadded image area
        probs = torch.softmax(output, dim=1).cpu().numpy()[0]
        probs = probs[:, :scaled_h, :scaled_w]

        # Visualize probability maps for each class
        plt.figure(figsize=(15, 10))
        for i in range(5):
            plt.subplot(2, 3, i+1)
            plt.imshow(probs[i], cmap='viridis', vmin=0, vmax=1)
            plt.title(f'{class_names[i]} Probability')
            plt.colorbar()
            plt.axis('off')
        plt.tight_layout()
        plt.show()

        # Get prediction (argmax)
        pred = np.argmax(probs, axis=0)

        # Count predictions
        unique, counts = np.unique(pred, return_counts=True)
        class_counts = {class_names[i]: 0 for i in range(5)}
        for cls, count in zip(unique, counts):
            class_counts[class_names[cls]] = count

        # Calculate total non-background percentage
        total_pixels = pred.size
        bg_pixels = class_counts['Background']
        non_bg_pixels = total_pixels - bg_pixels
        non_bg_percentage = (non_bg_pixels / total_pixels) * 100

        # Check if this is the best scale so far (most non-background pixels)
        if non_bg_percentage > best_non_bg_percentage:
            best_non_bg_percentage = non_bg_percentage
            best_scale = scale

        # Print result
        print("Class distribution:")
        for cls, count in class_counts.items():
            percentage = (count / pred.size) * 100
            print(f"{cls}: {count} pixels ({percentage:.2f}%)")

        # Create visualization mask
        vis_mask = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
        for i in range(5):
            vis_mask[pred == i] = colors[i]

        # Create overlay with original image
        alpha = 0.5  # Transparency factor
        overlay = scaled_image.copy()
        for i in range(1, 5):  # Skip background
            mask = pred == i
            if np.any(mask):
                overlay[mask] = alpha * overlay[mask] + (1 - alpha) * np.array(colors[i])

        # Create a figure with original, prediction, and overlay
        plt.figure(figsize=(20, 7))

        # Original image
        plt.subplot(1, 3, 1)
        plt.imshow(scaled_image)
        plt.title(f"Scaled Image ({scale:.2f}x)")
        plt.axis('off')

        # Segmentation mask
        plt.subplot(1, 3, 2)
        plt.imshow(vis_mask)
        plt.title(f"Segmentation (Scale: {scale:.2f})")
        plt.axis('off')

        # Overlay with legend
        plt.subplot(1, 3, 3)
        plt.imshow(overlay)
        plt.title(f"Overlay (Non-BG: {non_bg_percentage:.1f}%)")
        plt.axis('off')

        # Legend entries only for the classes that were actually predicted
        legend_elements = []
        for i in range(1, 5):  # Skip background
            if i in unique:
                # Convert RGB [0-255] to matplotlib's [0-1] scale
                color_rgb = [c/255 for c in colors[i]]
                percentage = (class_counts[class_names[i]] / pred.size) * 100
                legend_elements.append(
                    Patch(facecolor=color_rgb,
                          label=f"{class_names[i]} ({percentage:.1f}%)")
                )

        # Add legend
        plt.figlegend(handles=legend_elements, loc='lower center', ncol=4,
                     bbox_to_anchor=(0.5, -0.05), frameon=True, fancybox=True, shadow=True)

        plt.tight_layout()
        plt.subplots_adjust(bottom=0.15)  # Make room for the legend
        plt.show()

        # If any non-background classes are found, create a pie chart
        if non_bg_pixels > 0:
            plt.figure(figsize=(8, 8))

            # Prepare data for pie chart (excluding background)
            non_bg_counts = {k: v for k, v in class_counts.items() if k != 'Background' and v > 0}

            if non_bg_counts:
                labels = list(non_bg_counts.keys())
                sizes = list(non_bg_counts.values())

                # Corresponding colors for the pie chart
                pie_colors = []
                for label in labels:
                    idx = class_names.index(label)
                    color_rgb = [c/255 for c in colors[idx]]
                    pie_colors.append(color_rgb)

                # Create pie chart
                plt.pie(sizes, labels=labels, colors=pie_colors, autopct='%1.1f%%',
                       shadow=True, startangle=90)
                plt.axis('equal')  # Equal aspect ratio ensures pie is drawn as a circle
                plt.title(f"Cell Type Distribution at Scale {scale:.2f} (excluding background)")
                plt.show()

        # Store results
        results[scale] = {
            'class_counts': class_counts,
            'prediction': pred,
            'non_bg_percentage': non_bg_percentage
        }

    # Print summary and recommendation
    print("\n" + "="*50)
    print("SCALE COMPARISON SUMMARY")
    print("="*50)

    for scale, result in results.items():
        print(f"\nScale {scale:.2f}:")
        for cls, count in result['class_counts'].items():
            percentage = (count / result['prediction'].size) * 100
            print(f"  {cls}: {percentage:.2f}%")

    if best_scale is not None:
        print("\n" + "="*50)
        print(f"RECOMMENDED SCALE: {best_scale:.2f}")
        print(f"This scale provided the best cell detection with {best_non_bg_percentage:.2f}% non-background pixels")
        print("="*50)

    return results, best_scale

# Try with different scales on a TMA image
def analyze_tma_sample(tma_folder, model, device, scales=[1.0, 0.5, 0.25, 0.1]):
    """Run the multi-scale test on a TMA folder's high-resolution image.

    Args:
        tma_folder: Folder expected to contain ``tissue_hires_image.png``.
        model: Trained network, forwarded to ``test_tma_with_different_scales``.
        device: Torch device, forwarded to ``test_tma_with_different_scales``.
        scales: Resize factors to try. The list is only read, never modified.

    Returns:
        A dict with the keys ``'tma_folder'``, ``'best_scale'`` and
        ``'results'``, or ``None`` when the image file is missing or the
        multi-scale test produced no result.
    """
    # Path to high-resolution image
    img_path = os.path.join(tma_folder, "tissue_hires_image.png")

    if not os.path.exists(img_path):
        print(f"Image not found: {img_path}")
        return None

    print(f"Analyzing TMA sample: {os.path.basename(tma_folder)}")
    outcome = test_tma_with_different_scales(
        model,
        img_path,
        device,
        scales=scales
    )

    # The helper can return None instead of a pair when the file exists but
    # cannot be decoded; unpacking that would raise a TypeError
    if outcome is None:
        return None

    scale_results, best_scale = outcome
    return {
        'tma_folder': tma_folder,
        'best_scale': best_scale,
        'results': scale_results
    }

In [None]:
import os
from pathlib import Path
import matplotlib.pyplot as plt

# Ensure plots are displayed in the notebook
%matplotlib inline

# Collect every directory in the working directory whose name starts with "TMA_"
tma_folders = [
    folder for folder in os.listdir()
    if folder.startswith("TMA_") and os.path.isdir(folder)
]

if tma_folders:
    print(f"Found {len(tma_folders)} TMA folders:")
    for folder in tma_folders:
        print(f"  - {folder}")

    # Run the multi-scale analysis per folder, keeping only successful runs
    all_results = {}

    for tma_folder in tma_folders:
        print(f"\n{'='*50}")
        print(f"ANALYZING {tma_folder}")
        print(f"{'='*50}")

        result = analyze_tma_sample(
            tma_folder=tma_folder,
            model=trained_model,
            device=device,
            scales=[0.5, 0.25, 0.1]  # Use fewer scales to save time
        )

        if result:
            all_results[tma_folder] = result

    # Report the recommended scale and its cell-type breakdown for each folder
    print("\n\n" + "="*50)
    print("SUMMARY OF BEST SCALES FOR ALL TMA FOLDERS")
    print("="*50)

    for tma_folder, result in all_results.items():
        best_scale = result['best_scale']
        if best_scale is None:
            # No scale was recommended for this folder
            continue

        best_result = result['results'][best_scale]
        non_bg_percent = best_result['non_bg_percentage']
        print(f"{tma_folder}: Best scale = {best_scale:.2f} (Non-background: {non_bg_percent:.2f}%)")

        # Only non-background classes that were actually predicted are listed
        for cls, count in best_result['class_counts'].items():
            if cls == 'Background':
                continue
            percentage = (count / best_result['prediction'].size) * 100
            if percentage > 0:
                print(f"  {cls}: {percentage:.2f}%")
else:
    print("No TMA folders found in the current directory")

In [None]:
def analyze_svs_with_model(model, svs_path, device, patch_size=(984, 984), level=0):
    """
    Extract a patch from an SVS file and analyze it with the segmentation model.
    Will display results directly in the notebook.

    A patch of ``patch_size`` is read from the center of pyramid level
    ``level``, zero-padded on the bottom/right to a multiple of 32, normalized
    and segmented. The patch, the label map, an overlay and a pie chart of the
    predicted classes are displayed. The model is left in eval mode.

    Args:
        model: Segmentation network; expected to return per-class scores (class
            axis at dim 1) for the five classes listed in ``class_names``.
        svs_path: Path to the slide; converted with ``str()`` before opening.
        device: Torch device the input tensor is moved to.
        patch_size: ``(width, height)`` of the region to read at ``level``.
        level: Pyramid level to read; clamped to the highest available level.

    Returns:
        A dict with the keys ``'patch'`` (RGB array), ``'prediction'`` (label
        map cropped to the patch size) and ``'class_counts'``, or ``None`` if
        anything fails (the traceback is printed).
    """
    # openslide is imported lazily so the rest of the notebook works without it
    import openslide
    from matplotlib.patches import Patch

    print(f"Processing {svs_path}...")

    try:
        # Open the slide
        slide = openslide.OpenSlide(str(svs_path))
        try:
            # Get slide information
            dimensions = slide.level_dimensions
            downsamplings = slide.level_downsamples

            print(f"\nSlide Information:")
            print(f"Number of levels: {slide.level_count}")

            for i in range(slide.level_count):
                level_dim = dimensions[i]
                level_downsample = downsamplings[i]
                print(f"  Level {i}: {level_dim[0]} x {level_dim[1]}, downsample: {level_downsample:.2f}x")

            # If level is too high, adjust it
            if level >= slide.level_count:
                level = slide.level_count - 1
                print(f"Requested level too high, using level {level} instead")

            # Get level dimensions
            level_dims = dimensions[level]

            # Top-left corner of a centered patch, in this level's pixels
            x = (level_dims[0] - patch_size[0]) // 2
            y = (level_dims[1] - patch_size[1]) // 2

            # Ensure coordinates are non-negative
            x, y = max(0, x), max(0, y)

            # read_region expects the location in level-0 pixels. Multiply by
            # the float downsample before truncating so a non-integer factor
            # does not skew the offset.
            downsample = downsamplings[level]
            region = slide.read_region(
                (int(x * downsample), int(y * downsample)),
                level,
                patch_size
            )

            # Convert to RGB (remove alpha channel)
            patch = np.array(region.convert('RGB'))
        finally:
            # Release the slide handle even if reading the region failed
            slide.close()

        # Display the extracted patch
        plt.figure(figsize=(10, 10))
        plt.imshow(patch)
        plt.title(f"Extracted Patch - Level {level}")
        plt.axis('off')
        plt.show()  # Explicitly show the plot

        # Round both sides up to the next multiple of 32
        h, w = patch.shape[:2]
        new_h = ((h + 31) // 32) * 32
        new_w = ((w + 31) // 32) * 32

        # Pad on the bottom/right only. A.PadIfNeeded centres the image by
        # default, which would misalign the [:h, :w] crop of the prediction.
        padded = cv2.copyMakeBorder(
            patch, 0, new_h - h, 0, new_w - w, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )

        transform = A.Compose([
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2()
        ])

        print("Applying model...")

        transformed = transform(image=padded)
        image_tensor = transformed['image'].unsqueeze(0).to(device)

        # Get model output
        model.eval()
        with torch.no_grad():
            output = model(image_tensor)

        # Get probabilities and predictions, cropped back to the patch area
        probs = torch.softmax(output, dim=1).cpu().numpy()[0]
        pred = np.argmax(probs, axis=0)[:h, :w]

        # Define cell type colors and names
        colors = {
            0: [0, 0, 0],       # Background: black
            1: [255, 0, 0],     # Epithelial: red
            2: [0, 255, 0],     # Smooth muscle: green
            3: [0, 0, 255],     # Lymphocyte: blue
            4: [255, 255, 0]    # Myeloid: yellow
        }

        class_names = ['Background', 'Epithelial', 'Smooth Muscle', 'Lymphocyte', 'Myeloid']

        # Count predictions
        unique, counts = np.unique(pred, return_counts=True)
        class_counts = {class_names[i]: 0 for i in range(5)}
        for cls, count in zip(unique, counts):
            class_counts[class_names[cls]] = count

        # Print result
        print("\nClass distribution:")
        for cls, count in class_counts.items():
            percentage = (count / pred.size) * 100
            print(f"{cls}: {count} pixels ({percentage:.2f}%)")

        # Visualize prediction
        vis_mask = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
        for i in range(5):
            vis_mask[pred == i] = colors[i]

        # Create overlay image
        alpha = 0.5  # Transparency factor
        overlay = patch.copy()

        # For non-background classes, apply color overlay
        for i in range(1, 5):  # Skip background (0)
            mask = pred == i
            if np.any(mask):
                overlay[mask] = alpha * overlay[mask] + (1 - alpha) * np.array(colors[i])

        # Display all three visualizations together
        plt.figure(figsize=(20, 10))

        # Original image
        plt.subplot(1, 3, 1)
        plt.imshow(patch)
        plt.title("Original Image")
        plt.axis('off')

        # Pure segmentation
        plt.subplot(1, 3, 2)
        plt.imshow(vis_mask)
        plt.title("Cell Type Segmentation")
        plt.axis('off')

        # Overlay with legend
        plt.subplot(1, 3, 3)
        plt.imshow(overlay)
        plt.title("Segmentation Overlay")
        plt.axis('off')

        # Legend entries only for the classes that were actually predicted
        legend_elements = []
        for i in range(1, 5):  # Skip background
            if i in unique:
                # Convert RGB [0-255] to matplotlib's [0-1] scale
                color_rgb = [c/255 for c in colors[i]]
                legend_elements.append(
                    Patch(facecolor=color_rgb,
                          label=f"{class_names[i]} ({class_counts[class_names[i]]/pred.size*100:.1f}%)")
                )

        # Add legend to the figure
        plt.figlegend(handles=legend_elements, loc='lower center', ncol=4,
                      bbox_to_anchor=(0.5, 0.01), frameon=True, fancybox=True, shadow=True)

        plt.tight_layout()
        plt.subplots_adjust(bottom=0.15)  # Make room for the legend
        plt.show()  # Explicitly show the plot

        # Pie chart of the predicted cell types (excluding background), drawn
        # only when at least one non-background class was predicted
        non_bg_counts = {k: v for k, v in class_counts.items() if k != 'Background' and v > 0}

        if non_bg_counts:
            plt.figure(figsize=(8, 8))

            labels = list(non_bg_counts.keys())
            sizes = list(non_bg_counts.values())

            # Corresponding colors for the pie chart
            pie_colors = []
            for label in labels:
                idx = class_names.index(label)
                color_rgb = [c/255 for c in colors[idx]]
                pie_colors.append(color_rgb)

            # Create pie chart
            plt.pie(sizes, labels=labels, colors=pie_colors, autopct='%1.1f%%',
                    shadow=True, startangle=90)
            plt.axis('equal')  # Equal aspect ratio ensures pie is drawn as a circle
            plt.title("Cell Type Distribution (excluding background)")
            plt.show()  # Explicitly show the plot

        return {
            'patch': patch,
            'prediction': pred,
            'class_counts': class_counts
        }

    except Exception as e:
        # Best-effort: report the failure so the calling loop can continue
        print(f"Error processing slide: {e}")
        import traceback
        traceback.print_exc()
        return None

In [None]:
import os
from pathlib import Path
import matplotlib.pyplot as plt

# Ensure plots display inline in the notebook
%matplotlib inline

# Locate every SVS whole-slide image in the input folder. The matches are
# sorted so slides are processed and reported in the same order on every run;
# the raw directory listing order depends on the filesystem.
svs_folder = "H&E_Svs"
svs_files = sorted(Path(svs_folder).glob("*.svs"))

# Maps slide file name -> result dict returned by analyze_svs_with_model.
# Created before the branch below so `results` exists (empty) even when no
# slides are found and later cells can still reference it.
results = {}

if not svs_files:
    print(f"No SVS files found in {svs_folder}")
else:
    print(f"Found {len(svs_files)} SVS files:")
    for i, file in enumerate(svs_files):
        print(f"{i+1}. {file.name}")

    # Run the model on each slide. `trained_model` and `device` must already
    # have been defined by earlier cells in this notebook.
    for svs_file in svs_files:
        print(f"\n{'='*50}")
        print(f"Analyzing {svs_file.name}...")
        print(f"{'='*50}")

        # Only level 0 is analyzed; there is no fallback to other levels here.
        # NOTE(review): level 0 is presumably the full-resolution pyramid
        # level -- confirm against analyze_svs_with_model.
        analysis = analyze_svs_with_model(
            model=trained_model,
            svs_path=svs_file,
            device=device,
            level=0
        )

        # A falsy result (the helper appears to return None on error) means
        # the slide could not be processed, so it is left out of the summary.
        if analysis:
            results[svs_file.name] = analysis

    # Report each class's share of the predicted pixels, per slide.
    print("\nSummary of cell type distributions across all slides:")
    for slide_name, result in results.items():
        print(f"\n{slide_name}:")
        for cell_type, count in result['class_counts'].items():
            percentage = (count / result['prediction'].size) * 100
            print(f"  {cell_type}: {percentage:.2f}%")

## 14. Conclusion

In this notebook, we have:
1. Implemented a complete pipeline for segmenting cell types in H&E histology images
2. Trained a UNet model with a ResNet34 backbone on the SegPath dataset
3. Developed methods to apply the trained model to Visium HD spatial transcriptomics images
4. Created tools for comparing segmentation-based cell type identification with gene expression data

The trained model can be used to analyze the spatial distribution of cell types in tissue and correlate morphological features with gene expression patterns.

In [None]:
# Save the final trained model's weights (its state_dict) under output_dir.
# `trained_model` and `output_dir` are expected to come from earlier cells.
if 'trained_model' in locals():
    final_model_path = os.path.join(output_dir, 'final_segpath_model.pth')
    # Create the target directory if needed so torch.save does not fail on a
    # missing path; falls back to the current directory when output_dir is empty.
    os.makedirs(os.path.dirname(final_model_path) or '.', exist_ok=True)
    torch.save(trained_model.state_dict(), final_model_path)
    print(f"Final model saved to: {final_model_path}")
else:
    # Make the skipped save visible instead of silently reporting completion.
    print("No 'trained_model' found in this session; final model was not saved.")

print("Notebook execution completed!")