# Steatosis Segmentation using U-Net
This notebook trains a U-Net model to segment liver steatosis using the provided training and validation datasets.
It implements the full pipeline: data loading, optional Reinhard stain normalization, model definition, the training loop, and evaluation (baseline vs. Reinhard-normalized inputs).

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Dataset locations (modify these if your folder structure is different)
TRAIN_DIR = '/content/SteatosisU-UNet/train'
VAL_DIR = '/content/SteatosisU-UNet/val'

# Training hyperparameters
INPUT_SIZE = (256, 256)  # size handed to transforms.Resize for images and masks
BATCH_SIZE = 8
NUM_EPOCHS = 20
LEARNING_RATE = 1e-4

# Root folder for checkpoints; created up front so later saves cannot fail on a
# missing directory.
CHECKPOINT_DIR = 'checkpoints_steatosis'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
class SteatosisDataset(Dataset):
    """Paired image / binary-mask dataset for steatosis segmentation.

    Images are read from ``<root_dir>/image`` and masks from
    ``<root_dir>/manual_py`` (falling back to ``<root_dir>/manual``).
    Each sample is resized to the module-level ``INPUT_SIZE`` and returned as
    ``(image, mask)`` tensors produced by ``transforms.ToTensor``; the mask is
    binarized with a hard ``> 0`` threshold.
    """

    def __init__(self, root_dir, transform=None, normalization_method=None):
        """
        Args:
            root_dir (string): Directory with 'image' and 'manual' subdirs.
            transform (callable, optional): Stored on the instance for
                compatibility; ``__getitem__`` currently does not apply it.
            normalization_method (str): None, 'reinhard', or 'macenko'.
                Only 'reinhard' is implemented; any other truthy value
                leaves the image unchanged.
        """
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, 'image')
        # Prefer the 'manual_py' masks when they exist; otherwise use 'manual'.
        if os.path.exists(os.path.join(root_dir, 'manual_py')):
            self.mask_dir = os.path.join(root_dir, 'manual_py')
        else:
            self.mask_dir = os.path.join(root_dir, 'manual')
            print(f"Warning: 'manual_py' not found in {root_dir}. Using 'manual'.")

        self.transform = transform
        self.normalization_method = normalization_method
        if normalization_method and normalization_method != 'reinhard':
            # Make the no-op visible instead of silently training on raw images.
            print(f"Warning: normalization_method={normalization_method!r} is not "
                  f"implemented; images will be used without normalization.")

        # Load file lists
        valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif')
        images = sorted(f for f in os.listdir(self.image_dir)
                        if f.lower().endswith(valid_extensions))
        masks = sorted(f for f in os.listdir(self.mask_dir)
                       if f.lower().endswith(valid_extensions))

        # self.images[i] and self.masks[i] always describe the same sample.
        self.images, self.masks = self._pair_files(images, masks)

    def _pair_files(self, images, masks):
        """Align image and mask file names so that index i refers to one sample.

        Files are matched by name without extension. Pairing by sorted position
        alone would shift every later pair as soon as one file is missing, so
        position is only used when no image shares a name with any mask (for
        example when masks carry a suffix); in that case both lists are cut to
        the shorter length so indexing can never run past the end.
        """
        mask_by_stem = {os.path.splitext(name)[0]: name for name in masks}
        matched = [
            (name, mask_by_stem[os.path.splitext(name)[0]])
            for name in images
            if os.path.splitext(name)[0] in mask_by_stem
        ]

        if matched:
            if len(matched) != len(images) or len(matched) != len(masks):
                print(f"Warning: Mismatch between images and masks in {self.root_dir}: "
                      f"{len(images)} images, {len(masks)} masks, "
                      f"{len(matched)} matched by file name; unmatched files are skipped.")
            return [pair[0] for pair in matched], [pair[1] for pair in matched]

        # No shared names at all: fall back to positional pairing.
        if len(images) != len(masks):
            print(f"Warning: Mismatch between images and masks in {self.root_dir}")
        usable = min(len(images), len(masks))
        return images[:usable], masks[:usable]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Load, optionally normalize, resize and tensorize sample ``idx``.

        Returns:
            tuple: ``(image, mask)`` where ``image`` is the ``ToTensor`` output
            of the resized RGB image and ``mask`` is a float tensor holding
            only 0.0 and 1.0.
        """
        img_name = self.images[idx]
        mask_name = self.masks[idx]

        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, mask_name)

        # Load
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # Apply normalization BEFORE the resize/tensor transforms (needs numpy).
        if self.normalization_method == 'reinhard':
            try:
                image = Image.fromarray(reinhard_normalize(np.array(image)))
            except Exception as e:
                # Still best-effort (the raw image is used), but report it so a
                # "normalized" experiment cannot silently run on raw images.
                print(f"Warning: Reinhard normalization failed for {img_name}: {e}. "
                      f"Using the original image.")

        # Resize image (bilinear) and mask (nearest neighbour, so that no
        # intermediate grey values are invented along mask borders).
        resize_img = transforms.Resize(INPUT_SIZE, interpolation=transforms.InterpolationMode.BILINEAR)
        resize_mask = transforms.Resize(INPUT_SIZE, interpolation=transforms.InterpolationMode.NEAREST)
        image = resize_img(image)
        mask = resize_mask(mask)

        # ToTensor
        to_tensor = transforms.ToTensor()
        image = to_tensor(image)
        mask_t = to_tensor(mask)

        # Binarize mask (hard threshold at 0); works for 0/1 and 0/255 masks.
        mask_t = (mask_t > 0).float()

        return image, mask_t

In [None]:
# Reinhard Normalization Implementation
def reinhard_normalize(source_img, target_mu=None, target_sigma=None):
    """
    Normalizes an image using Reinhard's method.

    The image is converted RGB -> LMS -> log10 -> LAB, each LAB channel is
    shifted and scaled to the target mean / standard deviation, and the result
    is converted back to RGB.

    source_img: Input image (PIL or numpy array, RGB, uint8), shape (H, W, 3).
    target_mu, target_sigma: Per-channel (L, A, B) mean and std of the target
        appearance, expressed in the log10-based LAB space used by this
        function. When omitted, built-in reference statistics are used.
    Returns: numpy uint8 array of shape (H, W, 3).
    Raises: ValueError if the input is not an (H, W, 3) image.
    Reference: Reinhard et al. 2001
    """
    source_img = np.asarray(source_img)
    if source_img.ndim != 3 or source_img.shape[2] != 3:
        raise ValueError(
            f"reinhard_normalize expects an (H, W, 3) RGB image, got shape {source_img.shape}"
        )
    height, width = source_img.shape[:2]

    # Defaults (from a reference "good" slide).
    # The reference numbers are natural-log LAB statistics. This function works
    # in log10, where an 8-bit image cannot exceed L = sqrt(3) * log10(255)
    # ~= 4.17, so using 8.63 unscaled would push every pixel to ~1e5 and the
    # output would clip to solid white. LAB is linear in the log, so dividing
    # by ln(10) converts the statistics exactly.
    if target_mu is None:
        target_mu = np.array([8.6323, -0.1150, 0.0387]) / np.log(10)  # L, A, B means
    if target_sigma is None:
        target_sigma = np.array([0.5750, 0.1040, 0.0136]) / np.log(10)  # L, A, B stds
    target_mu = np.asarray(target_mu, dtype=float)
    target_sigma = np.asarray(target_sigma, dtype=float)

    # 1. RGB to LMS (dependency-free numpy conversion)
    M_rgb_lms = np.array([
        [0.3811, 0.5783, 0.0402],
        [0.1967, 0.7244, 0.0782],
        [0.0241, 0.1288, 0.8444]
    ])
    pixels = source_img.astype(float).reshape(-1, 3)
    # Floor the values so that pure black does not produce log(0).
    pixels = np.maximum(pixels, 1e-5)
    lms = np.log10(np.dot(pixels, M_rgb_lms.T))

    # 2. LMS to LAB
    M_lms_lab = np.array([
        [1/np.sqrt(3), 0, 0],
        [0, 1/np.sqrt(6), 0],
        [0, 0, 1/np.sqrt(2)]
    ]) @ np.array([
        [1, 1, 1],
        [1, 1, -2],
        [1, -1, 0]
    ])
    lab = np.dot(lms, M_lms_lab.T)

    # 3. Source statistics; floor sigma so flat channels are not scaled to infinity.
    mu = np.mean(lab, axis=0)
    sigma = np.maximum(np.std(lab, axis=0), 1e-5)

    # 4. Match the target statistics channel by channel.
    lab_norm = (lab - mu) * (target_sigma / sigma) + target_mu

    # 5. LAB to LMS (undo the log10)
    lms_norm = np.power(10.0, np.dot(lab_norm, np.linalg.inv(M_lms_lab).T))

    # 6. LMS to RGB, clipped back into the displayable 8-bit range.
    rgb_norm = np.dot(lms_norm, np.linalg.inv(M_rgb_lms).T)
    rgb_norm = np.clip(rgb_norm, 0, 255).astype(np.uint8)

    return rgb_norm.reshape(height, width, 3)

def visualize_reinhard_demo(dataset, num_samples=5):
    """
    Visualizes original vs reinhard normalized images.

    Up to ``num_samples`` images are drawn at random (without replacement)
    from ``dataset`` and shown next to their Reinhard-normalized versions.
    Only ``dataset.image_dir`` and ``dataset.images`` are read.
    """
    # Size the grid from the number of images actually drawn, so a dataset
    # smaller than num_samples does not leave empty rows in the figure.
    n_rows = min(len(dataset), num_samples)
    if n_rows <= 0:
        print("No images available to visualize.")
        return

    indices = np.random.choice(len(dataset), size=n_rows, replace=False)

    plt.figure(figsize=(10, 3 * n_rows))

    for i, idx in enumerate(indices):
        img_path = os.path.join(dataset.image_dir, dataset.images[idx])
        original = Image.open(img_path).convert("RGB")
        original_np = np.array(original)

        try:
            # Force reinhard regardless of dataset setting
            normalized_np = reinhard_normalize(original_np)

            plt.subplot(n_rows, 2, i*2 + 1)
            plt.imshow(original_np)
            plt.title(f"Original {idx}")
            plt.axis('off')

            plt.subplot(n_rows, 2, i*2 + 2)
            plt.imshow(normalized_np)
            plt.title("Reinhard Normalized")
            plt.axis('off')

        except Exception as e:
            # A single bad image should not abort the whole preview.
            print(f"Error normalizing {idx}: {e}")

    plt.tight_layout()
    plt.show()

print("Reinhard Normalizer Ready.")

In [None]:
# 1. Visualization of Reinhard Normalization (Before Training)
def _preview_reinhard(dataset, indices, split_label, file_prefix, output_dir):
    """Show original vs. Reinhard-normalized images and save the normalized ones.

    One figure row is drawn per entry of ``indices``. Each normalized image is
    written to ``output_dir`` as ``<file_prefix>_<original file name>``.
    Returns ``(fig, axes)``, or ``(None, None)`` when there is nothing to draw.
    """
    indices = list(indices)
    if not indices:
        # plt.subplots cannot build a grid with zero rows.
        print(f"No {split_label} samples to visualize.")
        return None, None

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(len(indices), 2, figsize=(8, 3 * len(indices)), squeeze=False)
    for row, idx in enumerate(indices):
        img_name = dataset.images[idx]
        img_path = os.path.join(dataset.image_dir, img_name)

        orig = np.array(Image.open(img_path).convert("RGB"))
        norm = reinhard_normalize(orig)

        # Plot
        axes[row, 0].imshow(orig)
        axes[row, 0].set_title(f"{split_label} Orig: {img_name}")
        axes[row, 0].axis('off')

        axes[row, 1].imshow(norm)
        axes[row, 1].set_title("Reinhard Norm")
        axes[row, 1].axis('off')

        # Save file
        Image.fromarray(norm).save(os.path.join(output_dir, f"{file_prefix}_{img_name}"))

    plt.tight_layout()
    plt.show()
    return fig, axes


print("Generating Reinhard Normalization Previews...")
# Create temporary datasets just for visualization (no normalization applied
# by the dataset itself; it is applied explicitly in the preview helper).
temp_viz_dataset_train = SteatosisDataset(TRAIN_DIR, normalization_method=None)
temp_viz_dataset_val = SteatosisDataset(VAL_DIR, normalization_method=None)

# Save images directory
VIZ_OUTPUT_DIR = 'normalization_previews_reinhard'
os.makedirs(VIZ_OUTPUT_DIR, exist_ok=True)

# Up to 20 random samples from Train (fewer when the split is smaller, since
# sampling without replacement cannot return more items than exist).
n_train_preview = min(20, len(temp_viz_dataset_train))
print(f"Visualizing {n_train_preview} Train samples...")
if n_train_preview > 0:
    subset_indices = np.random.choice(len(temp_viz_dataset_train), n_train_preview, replace=False)
else:
    subset_indices = np.array([], dtype=int)
fig, axs = _preview_reinhard(
    temp_viz_dataset_train, subset_indices, "Train", "norm_train", VIZ_OUTPUT_DIR
)

# All samples from Val
print(f"Visualizing ALL ({len(temp_viz_dataset_val)}) Val samples...")
fig_v, axs_v = _preview_reinhard(
    temp_viz_dataset_val, range(len(temp_viz_dataset_val)), "Val", "norm_val", VIZ_OUTPUT_DIR
)
print(f"Normalized images saved to {VIZ_OUTPUT_DIR}")

In [None]:
# 2. DEFINITION OF TRAINING FUNCTION
# Wrapping the training loop in a function to run it twice (Baseline vs Reinhard)
def run_training_experiment(experiment_name, train_dataset, val_dataset):
    """Train a fresh U-Net on ``train_dataset`` and validate it after every epoch.

    A new model, loss (BCE with logits) and Adam optimizer are created on each
    call, so successive experiments do not share weights. The state dict with
    the best mean per-image strict Dice on the validation set is saved to
    ``<CHECKPOINT_DIR>/<experiment_name>/best_model.pth`` and the metric
    history to ``history.json`` in the same folder.

    Args:
        experiment_name (str): Label used in log lines and as the output sub-folder.
        train_dataset: Dataset yielding ``(image, mask)`` pairs for training.
        val_dataset: Dataset yielding ``(image, mask)`` pairs for validation.

    Returns:
        dict: Per-epoch lists under the keys 'train_loss', 'val_loss',
        'val_dice_strict_avg' and 'val_dice_strict_global'.

    Raises:
        ValueError: If either dataset is empty.
    """
    # Fail early with a clear message instead of a ZeroDivisionError after the
    # first epoch has already been trained.
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise ValueError(
            f"[{experiment_name}] Empty dataset: {len(train_dataset)} training and "
            f"{len(val_dataset)} validation samples."
        )

    print(f"\n{'='*20} RUNNING EXPERIMENT: {experiment_name} {'='*20}")

    # Loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Init Model & Ops
    model = UNet(in_channels=3, out_channels=1, init_filters=32, depth=4).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Metrics Storage
    history = {
        'train_loss': [], 'val_loss': [],
        'val_dice_strict_avg': [], 'val_dice_strict_global': []
    }
    # Start below any reachable Dice so the first epoch always writes a
    # checkpoint, even when the model scores exactly 0.
    best_metric = float('-inf')

    save_dir = os.path.join(CHECKPOINT_DIR, experiment_name)
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(NUM_EPOCHS):
        # TRAIN
        model.train()
        running_loss = 0.0
        pbar = tqdm(train_loader, desc=f'[{experiment_name}] Ep {epoch+1}/{NUM_EPOCHS}', leave=False)

        for images, masks in pbar:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller last batch is averaged correctly.
            running_loss += loss.item() * images.size(0)

        epoch_train_loss = running_loss / len(train_dataset)

        # VALIDATION
        model.eval()
        running_val_loss = 0.0
        val_stats = {'sum_dice_strict': 0, 'total_int': 0, 'total_union': 0, 'n_samples': 0}

        with torch.no_grad():
            for images, masks in val_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                running_val_loss += loss.item() * images.size(0)

                # Metrics
                batch_s = compute_batch_stats(outputs, masks)
                for k in val_stats:
                    val_stats[k] += batch_s[k]

        epoch_val_loss = running_val_loss / len(val_dataset)
        v_strict_avg = val_stats['sum_dice_strict'] / val_stats['n_samples']
        # Global Dice over all validation pixels. When neither prediction nor
        # target contains a positive pixel the score is 1.0, matching the
        # per-image strict Dice convention; no smoothing term is used.
        if val_stats['total_union'] > 0:
            v_strict_glob = (2. * val_stats['total_int']) / val_stats['total_union']
        else:
            v_strict_glob = 1.0

        history['train_loss'].append(epoch_train_loss)
        history['val_loss'].append(epoch_val_loss)
        history['val_dice_strict_avg'].append(v_strict_avg)
        history['val_dice_strict_global'].append(v_strict_glob)

        print(f"[{experiment_name}] Ep {epoch+1} | Loss: {epoch_train_loss:.4f} / {epoch_val_loss:.4f} | Val Dice Strict: {v_strict_avg:.4f}")

        if v_strict_avg > best_metric:
            best_metric = v_strict_avg
            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))

    # Save History
    with open(os.path.join(save_dir, 'history.json'), 'w') as f:
        json.dump(history, f)

    return history

# --------------------------------------------------------
# 3. EXECUTE BASELINE (NO NORMALIZATION)
# --------------------------------------------------------
# NOTE: run_training_experiment looks up UNet and compute_batch_stats when it
# is called; both are defined in later cells of this notebook, so those cells
# must have been executed before this one.
ds_train_base = SteatosisDataset(root_dir=TRAIN_DIR, normalization_method=None)
ds_val_base = SteatosisDataset(root_dir=VAL_DIR, normalization_method=None)
history_baseline = run_training_experiment("BASELINE", ds_train_base, ds_val_base)

# --------------------------------------------------------
# 4. EXECUTE REINHARD (NORMALIZED)
# --------------------------------------------------------
ds_train_reinhard = SteatosisDataset(root_dir=TRAIN_DIR, normalization_method='reinhard')
ds_val_reinhard = SteatosisDataset(root_dir=VAL_DIR, normalization_method='reinhard')
history_reinhard = run_training_experiment("REINHARD", ds_train_reinhard, ds_val_reinhard)

# --------------------------------------------------------
# 5. COMPARISON PLOT
# --------------------------------------------------------
# One panel per metric, both experiments overlaid in each panel.
comparison_panels = (
    ('val_dice_strict_avg', "Validation Dice Strict (Avg)"),
    ('val_loss', "Validation Loss"),
)
comparison_runs = (
    ('Baseline', history_baseline),
    ('Reinhard', history_reinhard),
)

plt.figure(figsize=(12, 5))
for panel_pos, (metric_key, panel_title) in enumerate(comparison_panels, start=1):
    plt.subplot(1, 2, panel_pos)
    for run_label, run_history in comparison_runs:
        plt.plot(run_history[metric_key], label=run_label, marker='.')
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)

plt.show()

In [None]:
class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 convolution -> batch norm -> ReLU.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. ``padding=1`` leaves the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage inputs: the block input first, then the first stage's output.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features

class UNet(nn.Module):
    """U-Net style encoder / bottleneck / decoder built from DoubleConv blocks.

    The encoder has ``depth`` DoubleConv stages whose channel count starts at
    ``init_filters`` and doubles at every stage; each stage output is kept as a
    skip connection and then max-pooled by 2. The decoder mirrors this with
    ``depth`` upsample-concatenate-DoubleConv stages, and a final 1x1
    convolution maps to ``out_channels``. ``forward`` applies no activation to
    that last convolution, so the output is raw logits.
    """
    def __init__(self,
                 in_channels=3,
                 out_channels=1,
                 init_filters=64,
                 depth=4,
                 bilinear=True):
        """
        Args:
            in_channels: Number of channels of the input image.
            out_channels: Number of channels produced by the final 1x1 convolution.
            init_filters: Channel count of the first encoder stage.
            depth: Number of encoder stages (and of decoder stages).
            bilinear: If True, upsample with bilinear interpolation followed by
                a 1x1 convolution; otherwise use a stride-2 transposed convolution.
        """
        super(UNet, self).__init__()
        self.depth = depth
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.pool = nn.MaxPool2d(2)

        # Encoder: channel count doubles at every stage
        filters = init_filters
        for d in range(depth):
            conv = DoubleConv(in_channels, filters)
            self.down_layers.append(conv)
            in_channels = filters
            filters *= 2

        # Bottleneck: after the loop `filters` is init_filters * 2**depth
        self.bottleneck = DoubleConv(in_channels, filters)

        # Decoder: each stage halves the channel count again
        for d in range(depth):
            filters //= 2
            if bilinear:
                # Upsample by 2, then reduce channels with a 1x1 convolution
                up = nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                    nn.Conv2d(filters * 2, filters, kernel_size=1)
                )
            else:
                up = nn.ConvTranspose2d(filters * 2, filters, kernel_size=2, stride=2)
            # 'conv' takes filters * 2 channels: the skip connection
            # concatenated with the upsampled features
            self.up_layers.append(nn.ModuleDict({
                'up': up,
                'conv': DoubleConv(filters * 2, filters)
            }))

        # Output layer
        self.out_conv = nn.Conv2d(init_filters, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return the output of the final 1x1 convolution.

        NOTE(review): presumably ``x`` is a batch shaped (N, in_channels, H, W) - confirm with callers.
        """
        # Encoder: remember each stage's output before pooling
        skip_connections = []
        for down in self.down_layers:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder: consume the skip connections from deepest to shallowest
        for i in range(self.depth):
            skip = skip_connections[-(i+1)]
            up = self.up_layers[i]['up'](x)
            if up.size() != skip.size():
                # Resize in case of odd size mismatch
                up = F.interpolate(up, size=skip.shape[2:])
            x = torch.cat([skip, up], dim=1)
            x = self.up_layers[i]['conv'](x)

        return self.out_conv(x)

In [None]:
# Metric Functions
def compute_batch_stats(pred_logits, target_mask):
    """
    Calculates statistics for a batch to compute:
    1. Dice Standard (Mean per Image)
    2. Dice Strict (Mean per Image)
    3. Global Stats (Intersection & Union) for Batch-Based/Global Dice

    Args:
        pred_logits: Raw model outputs; the first dimension is the batch.
            They are passed through a sigmoid and thresholded at 0.5.
        target_mask: Ground-truth masks with the same batch size and the same
            number of elements per sample as ``pred_logits``.

    Returns:
        dict of Python numbers:
            'sum_dice_std'    - sum over images of the smoothed Dice
            'sum_dice_strict' - sum over images of the unsmoothed Dice
                                (1.0 when prediction and target are both empty)
            'total_int'       - total intersection over the batch
            'total_union'     - total of (predicted + target) positives; this is
                                the Dice denominator, not a set union
            'n_samples'       - batch size
    """
    # Sigmoid & Binarization
    probs = torch.sigmoid(pred_logits)
    pred = (probs > 0.5).float()

    # Flatten: (B, C, H, W) -> (B, -1). reshape (unlike view) also accepts
    # non-contiguous tensors, e.g. masks that were transposed or sliced.
    pred_flat = pred.reshape(pred.size(0), -1)
    target_flat = target_mask.reshape(target_mask.size(0), -1)

    # Intersection & Sum per image
    intersection = (pred_flat * target_flat).sum(dim=1)
    union_raw = pred_flat.sum(dim=1) + target_flat.sum(dim=1)

    # --- 1. Dice Standard Per Image (with smooth) ---
    smooth = 1e-5
    dice_std_img = (2. * intersection + smooth) / (union_raw + smooth)

    # --- 2. Dice Strict Per Image (no smooth, handle empty) ---
    # If both empty (union=0) -> 1.0, else 2*I / U. The denominator is replaced
    # by 1 for empty images so that the discarded branch never divides by zero.
    is_empty = (union_raw == 0)
    ones = torch.ones_like(intersection)
    safe_union = torch.where(is_empty, ones, union_raw)
    dice_strict_img = torch.where(is_empty, ones, (2. * intersection) / safe_union)

    # --- 3. For Global/Batch Calculation ---
    total_intersection = intersection.sum().item()
    total_union = union_raw.sum().item()

    return {
        'sum_dice_std': dice_std_img.sum().item(),
        'sum_dice_strict': dice_strict_img.sum().item(),
        'total_int': total_intersection,
        'total_union': total_union,
        'n_samples': pred.size(0)
    }

# Initialize Model
# Module-level instance, moved to the selected device. Note that
# run_training_experiment builds its own model, criterion and optimizer on
# every call and does not use the objects created here.
model = UNet(in_channels=3, out_channels=1, init_filters=32, depth=4).to(device)

# Loss Function
# Binary cross-entropy applied to raw logits (the UNet forward pass applies no
# final activation).
criterion = nn.BCEWithLogitsLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Model initialized with detailed metric tracking.")