# Steatosis Segmentation using U-Net
This notebook trains a U-Net model to segment liver steatosis using the provided training and validation datasets.
It implements the full training pipeline: data loading, model definition, training loop, and evaluation.

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Training hyper-parameters
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 20
INPUT_SIZE = (256, 256)  # size every image and mask is resized to

# Folder that receives the model checkpoints; created here if it is missing.
CHECKPOINT_DIR = 'checkpoints_steatosis'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Dataset locations (edit these if your folder structure is different)
TRAIN_DIR = '/content/SteatosisU-UNet/train'
VAL_DIR = '/content/SteatosisU-UNet/val'

In [None]:
class SteatosisDataset(Dataset):
    """Paired image / binary-mask dataset for steatosis segmentation.

    Images are read from ``<root_dir>/image`` and masks from
    ``<root_dir>/manual_py``.  An image and a mask belong together when their
    filenames share the same stem (name without extension).  Only when no
    stem is shared at all does the dataset fall back to pairing the two
    sorted file lists position by position.

    After construction ``self.images`` and ``self.masks`` have the same
    length and ``self.masks[i]`` is the mask paired with ``self.images[i]``.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (string): Directory with 'image' and 'manual_py' subdirs.
            transform (callable, optional): Stored as ``self.transform``.
                NOTE(review): it is not applied anywhere in this class.
        """
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, 'image')
        # Masks for training live in the 'manual_py' folder.
        self.mask_dir = os.path.join(root_dir, 'manual_py')
        self.transform = transform

        # Only files with a common image extension are considered.
        valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif')
        images = self._list_files(self.image_dir, valid_extensions)
        masks = self._list_files(self.mask_dir, valid_extensions)

        self.images, self.masks = self._pair_files(images, masks, root_dir)

    @staticmethod
    def _list_files(directory, extensions):
        """Return the sorted filenames in *directory* ending with one of *extensions*."""
        return sorted(f for f in os.listdir(directory) if f.lower().endswith(extensions))

    @staticmethod
    def _pair_files(images, masks, root_dir):
        """Return aligned ``(images, masks)`` filename lists of equal length."""
        if len(images) != len(masks):
            print(f"Warning: Number of images ({len(images)}) and masks ({len(masks)}) in {root_dir} do not match!")

        mask_by_stem = {os.path.splitext(name)[0]: name for name in masks}
        matched = [
            (name, mask_by_stem[os.path.splitext(name)[0]])
            for name in images
            if os.path.splitext(name)[0] in mask_by_stem
        ]

        if matched:
            paired_images = [img for img, _ in matched]
            paired_masks = [msk for _, msk in matched]
            skipped_images = len(images) - len(paired_images)
            skipped_masks = len(masks) - len(set(paired_masks))
            if skipped_images or skipped_masks:
                print(f"Warning: skipping {skipped_images} image(s) and {skipped_masks} mask(s) "
                      f"in {root_dir} that have no partner with the same file stem.")
            return paired_images, paired_masks

        # No stem in common: the two folders use different naming schemes, so
        # pair by sorted position and drop the unpaired tail so that indexing
        # can never run past the shorter list.
        n_pairs = min(len(images), len(masks))
        return images[:n_pairs], masks[:n_pairs]

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load pair *idx* and return ``(image, mask)`` as float tensors.

        Both are resized to ``INPUT_SIZE`` and converted with ``ToTensor``;
        the mask is then binarised so it holds only 0.0 and 1.0.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # convert() returns an independent in-memory copy, so the underlying
        # file can be closed as soon as the conversion is done.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")  # Grayscale for mask

        # Resize image (bilinear interpolation is fine for photographs)
        resize_img = transforms.Resize(INPUT_SIZE, interpolation=transforms.InterpolationMode.BILINEAR)
        image = resize_img(image)

        # Resize mask (NEAREST is essential so no intermediate label values
        # are invented between background and foreground)
        resize_mask = transforms.Resize(INPUT_SIZE, interpolation=transforms.InterpolationMode.NEAREST)
        mask = resize_mask(mask)

        to_tensor = transforms.ToTensor()
        image = to_tensor(image)
        mask = to_tensor(mask)

        # ToTensor scales 8-bit pixels to [0, 1].  A mask stored as 0/1 would
        # come out as 0 and 1/255, one stored as 0/255 as 0 and 1.  To cover
        # both encodings, every value > 0 is treated as foreground.
        mask = (mask > 0).float()

        return image, mask

In [None]:
# Build the training and validation datasets (training first).
train_dataset, val_dataset = SteatosisDataset(TRAIN_DIR), SteatosisDataset(VAL_DIR)

# Wrap them in DataLoaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

for split_name, split_dataset in (("Training", train_dataset), ("Validation", val_dataset)):
    print(f"{split_name} set: {len(split_dataset)} images")

# Sanity check: inspect and display the first training sample
temp_img, temp_mask = train_dataset[0]
for tensor_label, sample_tensor in (("Image", temp_img), ("Mask", temp_mask)):
    print(f"{tensor_label} tensor shape: {sample_tensor.shape}")
print(f"Mask unique values: {torch.unique(temp_mask)}")

plt.figure(figsize=(10, 5))
sample_panels = (
    # permute moves the channel axis last, which is what imshow expects
    (temp_img.permute(1, 2, 0), "Sample Image", {}),
    (temp_mask.squeeze(), "Sample Mask", {'cmap': 'gray'}),
)
for panel_pos, (panel_data, panel_title, imshow_kwargs) in enumerate(sample_panels, start=1):
    plt.subplot(1, 2, panel_pos)
    plt.imshow(panel_data, **imshow_kwargs)
    plt.title(panel_title)
    plt.axis('off')
plt.show()

In [None]:
class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 convolution -> batch norm -> ReLU.

    The first stage maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``.  ``padding=1`` leaves the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run *x* through both conv stages and return the result."""
        features = self.double_conv(x)
        return features

class UNet(nn.Module):
    """U-Net with a configurable number of encoder/decoder levels.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the final 1x1 convolution.
        init_filters: Feature channels of the first encoder level; the count
            doubles at every deeper level.
        depth: Number of encoder levels (and of matching decoder levels).
        bilinear: If true, upsample with bilinear interpolation followed by a
            1x1 convolution; otherwise use a stride-2 transposed convolution.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=1,
                 init_filters=64,
                 depth=4,
                 bilinear=True):
        super().__init__()
        self.depth = depth
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.pool = nn.MaxPool2d(2)

        # Encoder: each level doubles the channel count.
        width = init_filters
        level_in = in_channels
        for _ in range(depth):
            self.down_layers.append(DoubleConv(level_in, width))
            level_in = width
            width *= 2

        # Bottleneck sits below the deepest encoder level.
        self.bottleneck = DoubleConv(level_in, width)

        # Decoder: each level halves the channel count again.
        for _ in range(depth):
            width //= 2
            if bilinear:
                upsampler = nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                    nn.Conv2d(width * 2, width, kernel_size=1),
                )
            else:
                upsampler = nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2)
            # The conv consumes skip + upsampled features, hence width * 2 inputs.
            stage = nn.ModuleDict({
                'up': upsampler,
                'conv': DoubleConv(width * 2, width),
            })
            self.up_layers.append(stage)

        # Final 1x1 projection to the requested number of output channels
        self.out_conv = nn.Conv2d(init_filters, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the raw (pre-sigmoid) output map for the batch *x*."""
        skips = []
        for encoder in self.down_layers:
            x = encoder(x)
            skips.append(x)  # kept before pooling, for the matching decoder level
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder levels consume the skips deepest-first.
        for level in range(self.depth):
            stage = self.up_layers[level]
            skip = skips[-(level + 1)]
            upsampled = stage['up'](x)
            if upsampled.size() != skip.size():
                # Odd input sizes lose a pixel when pooled; resize to the skip.
                upsampled = F.interpolate(upsampled, size=skip.shape[2:])
            x = stage['conv'](torch.cat([skip, upsampled], dim=1))

        return self.out_conv(x)

In [None]:
# Metric Functions
def compute_batch_stats(pred_logits, target_mask):
    """Compute Dice statistics for one batch.

    Logits are passed through a sigmoid and thresholded at 0.5.  Per-image
    quantities are summed over the batch so the caller can add the returned
    values across batches and normalise once per epoch.

    Args:
        pred_logits: Raw model outputs; the first dimension is the batch.
        target_mask: Ground-truth masks with the same number of elements per
            image.  Any dtype and memory layout is accepted.

    Returns:
        dict with:
            'sum_dice_std'    - sum of per-image Dice with smoothing (1e-5),
            'sum_dice_strict' - sum of per-image Dice without smoothing, where
                                an image with empty prediction AND empty
                                target scores 1.0,
            'total_int'       - sum of per-image intersections,
            'total_union'     - sum of per-image |pred| + |target| (the Dice
                                denominator, not a set union),
            'n_samples'       - number of images in the batch.
    """
    # Sigmoid & binarisation
    pred = (torch.sigmoid(pred_logits) > 0.5).float()

    # Flatten (B, ...) -> (B, -1).  reshape (not view) also handles
    # non-contiguous inputs; the cast makes bool / integer masks behave like
    # float ones.
    pred_flat = pred.reshape(pred.size(0), -1)
    target_flat = target_mask.reshape(target_mask.size(0), -1).to(pred_flat.dtype)

    # Intersection and Dice denominator per image
    intersection = (pred_flat * target_flat).sum(dim=1)
    union_raw = pred_flat.sum(dim=1) + target_flat.sum(dim=1)

    # --- 1. Standard Dice per image (smoothed) ---
    smooth = 1e-5
    dice_std_img = (2. * intersection + smooth) / (union_raw + smooth)

    # --- 2. Strict Dice per image (no smoothing) ---
    # Empty images get a dummy denominator of 1 to avoid 0/0, then the score
    # for those images is overwritten with 1.0.
    is_empty = union_raw == 0
    ones = torch.ones_like(union_raw)
    safe_union = torch.where(is_empty, ones, union_raw)
    dice_strict_img = torch.where(is_empty, ones, (2. * intersection) / safe_union)

    # --- 3. Totals for the global (dataset-level) Dice ---
    return {
        'sum_dice_std': dice_std_img.sum().item(),
        'sum_dice_strict': dice_strict_img.sum().item(),
        'total_int': intersection.sum().item(),
        'total_union': union_raw.sum().item(),
        'n_samples': pred.size(0)
    }

# Build the network, then move its parameters to the selected device
model = UNet(in_channels=3, out_channels=1, init_filters=32, depth=4)
model = model.to(device)

# Binary cross-entropy computed directly on the raw logits
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Model initialized with detailed metric tracking.")

In [None]:
# History storage: one list per tracked quantity, one entry appended per epoch.
history = {
    'train_loss': [], 'val_loss': [],
    'train_dice_std_avg': [], 'val_dice_std_avg': [],
    'train_dice_strict_avg': [], 'val_dice_strict_avg': [],
    'train_dice_std_global': [], 'val_dice_std_global': [],
    'train_dice_strict_global': [], 'val_dice_strict_global': []
}


def _run_epoch(loader, n_items, desc, train):
    """Run one pass over *loader* and return ``(mean_loss, stats)``.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.  With ``train=True`` the model is put in training mode and
    the optimizer is stepped after every batch; otherwise the model is put in
    eval mode and no gradients are computed.  ``mean_loss`` is the
    sample-weighted loss sum divided by *n_items*; ``stats`` accumulates the
    values returned by ``compute_batch_stats`` over all batches.
    """
    if train:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    stats = {'sum_dice_std': 0, 'sum_dice_strict': 0, 'total_int': 0, 'total_union': 0, 'n_samples': 0}

    with torch.set_grad_enabled(train):
        for images, masks in tqdm(loader, desc=desc, leave=False):
            images = images.to(device)
            masks = masks.to(device)

            if train:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            if train:
                loss.backward()
                optimizer.step()

            # The criterion returns a batch mean, so weight it by batch size.
            running_loss += loss.item() * images.size(0)

            # Metrics never need gradients, even during training.
            with torch.no_grad():
                batch_s = compute_batch_stats(outputs, masks)
            for k in stats:
                stats[k] += batch_s[k]

    return running_loss / n_items, stats


def _summarize_stats(stats, smooth=1e-5):
    """Turn accumulated stats into ``(std_avg, strict_avg, std_global, strict_global)``.

    The ``*_avg`` values are per-image means; the ``*_global`` values are
    computed from the intersection / denominator totals of the whole epoch.
    """
    n_samples = stats['n_samples']
    total_int = stats['total_int']
    total_union = stats['total_union']

    dice_std_avg = stats['sum_dice_std'] / n_samples
    dice_strict_avg = stats['sum_dice_strict'] / n_samples
    dice_std_global = (2. * total_int + smooth) / (total_union + smooth)

    if total_union == 0:
        # Nothing predicted and nothing annotated anywhere: perfect agreement
        # as long as the intersection is empty too.
        dice_strict_global = 1.0 if total_int == 0 else 0.0
    else:
        dice_strict_global = (2. * total_int) / total_union

    return dice_std_avg, dice_strict_avg, dice_std_global, dice_strict_global


best_metric = 0.0
# Tracks whether THIS run has written best_model.pth yet.  Without it, a run
# whose validation Dice never rises above 0.0 would leave no checkpoint (or a
# stale one from an earlier run) for the code that later loads best_model.pth.
best_saved = False
best_model_path = os.path.join(CHECKPOINT_DIR, 'best_model.pth')

print(f"Starting training for {NUM_EPOCHS} epochs...")

for epoch in range(NUM_EPOCHS):
    # --- TRAINING ---
    epoch_train_loss, train_stats = _run_epoch(
        train_loader, len(train_dataset), f'Epoch {epoch+1}/{NUM_EPOCHS} - Train', train=True)
    t_dice_std_avg, t_dice_strict_avg, t_dice_std_global, t_dice_strict_global = _summarize_stats(train_stats)

    # --- VALIDATION ---
    epoch_val_loss, val_stats = _run_epoch(
        val_loader, len(val_dataset), f'Epoch {epoch+1}/{NUM_EPOCHS} - Val', train=False)
    v_dice_std_avg, v_dice_strict_avg, v_dice_std_global, v_dice_strict_global = _summarize_stats(val_stats)

    # Update History
    history['train_loss'].append(epoch_train_loss)
    history['val_loss'].append(epoch_val_loss)
    history['train_dice_std_avg'].append(t_dice_std_avg)
    history['val_dice_std_avg'].append(v_dice_std_avg)
    history['train_dice_strict_avg'].append(t_dice_strict_avg)
    history['val_dice_strict_avg'].append(v_dice_strict_avg)
    history['train_dice_std_global'].append(t_dice_std_global)
    history['val_dice_std_global'].append(v_dice_std_global)
    history['train_dice_strict_global'].append(t_dice_strict_global)
    history['val_dice_strict_global'].append(v_dice_strict_global)

    # Print Report
    print(f"Epoch {epoch+1:02d} | Loss: T={epoch_train_loss:.4f} V={epoch_val_loss:.4f}")
    print(f"   [Per-Image]  Dice Strict: T={t_dice_strict_avg:.4f} V={v_dice_strict_avg:.4f} | Std: T={t_dice_std_avg:.4f} V={v_dice_std_avg:.4f}")
    print(f"   [Batch-Based] Dice Strict: T={t_dice_strict_global:.4f} V={v_dice_strict_global:.4f} | Std: T={t_dice_std_global:.4f} V={v_dice_std_global:.4f}")

    # Save Checkpoints
    current_metric = v_dice_strict_avg  # Use Per-Image Strict Dice as main criteria
    if not best_saved or current_metric > best_metric:
        best_metric = current_metric
        torch.save(model.state_dict(), best_model_path)
        best_saved = True
        print(f"   --> NEW BEST MODEL (Dice Strict Avg: {best_metric:.4f})")

    # Periodic snapshot every 5 epochs
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f'model_epoch_{epoch+1}.pth'))

# Save final params.  The last global Dice is read from history so this also
# works when no epoch was run (the value is then null in the JSON).
final_val_dice_strict_global = (
    history['val_dice_strict_global'][-1] if history['val_dice_strict_global'] else None
)
params = {
    "config": {"batch_size": BATCH_SIZE, "epochs": NUM_EPOCHS, "input_size": INPUT_SIZE},
    "results": {
        "best_val_dice_strict_avg": best_metric,
        "final_val_dice_strict_global": final_val_dice_strict_global
    }
}
with open(os.path.join(CHECKPOINT_DIR, 'training_results.json'), 'w') as f:
    json.dump(params, f, indent=4)

print("Training completed.")

In [None]:
# Plot Comprehensive Metrics
plt.figure(figsize=(18, 10))

# One entry per 2x2 panel: (title, [(history key, legend label, extra plot kwargs), ...])
metric_panels = [
    # 1. Losses
    ('BCE Loss', [
        ('train_loss', 'Train Loss', {}),
        ('val_loss', 'Val Loss', {}),
    ]),
    # 2. Strict Dice (Per Image - The main metric)
    ('Dice Strict (Per-Image Average)', [
        ('train_dice_strict_avg', 'Train Strict (Avg)', {}),
        ('val_dice_strict_avg', 'Val Strict (Avg)', {}),
    ]),
    # 3. Global vs Average Comparison (Validation Only)
    ('Validation: Average vs Global (Strict)', [
        ('val_dice_strict_avg', 'Val Strict (Avg)', {'linestyle': '--'}),
        ('val_dice_strict_global', 'Val Strict (Global)', {}),
    ]),
    # 4. Standard vs Strict Comparison (Validation Only)
    ('Validation: Standard vs Strict (Average)', [
        ('val_dice_std_avg', 'Val Standard (Avg)', {'linestyle': '--'}),
        ('val_dice_strict_avg', 'Val Strict (Avg)', {}),
    ]),
]

for panel_pos, (panel_title, curves) in enumerate(metric_panels, start=1):
    plt.subplot(2, 2, panel_pos)
    for history_key, curve_label, extra_kwargs in curves:
        plt.plot(history[history_key], label=curve_label, marker='.', **extra_kwargs)
    plt.title(panel_title)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()

# --- Visualization of Predictions (Best Model) ---
# map_location lets the checkpoint load on whatever device is in use now,
# even if it was written from a different one.
model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, 'best_model.pth'), map_location=device))
model.eval()

# Take the first validation batch
images, masks = next(iter(val_loader))
images = images.to(device)

with torch.no_grad():
    outputs = model(images)
    preds = torch.sigmoid(outputs) > 0.5

# Show at most three samples, one row each: image | ground truth | prediction
n_plot = min(3, images.size(0))
# squeeze=False keeps `axs` two-dimensional even when there is a single row.
fig, axs = plt.subplots(n_plot, 3, figsize=(15, 5*n_plot), squeeze=False)

images_np = images.cpu().permute(0, 2, 3, 1).numpy()
# Drop only the channel axis.  A bare squeeze() would also drop the batch
# axis when the batch holds one image, and indexing [i] would then select a
# pixel row instead of a sample.
masks_np = masks.cpu().squeeze(1).numpy()
preds_np = preds.float().cpu().squeeze(1).numpy()

for i in range(n_plot):
    # Original
    axs[i, 0].imshow(images_np[i])
    axs[i, 0].set_title(f"Image {i+1}")
    axs[i, 0].axis('off')

    # Ground Truth
    axs[i, 1].imshow(masks_np[i], cmap='gray', vmin=0, vmax=1)
    axs[i, 1].set_title("Ground Truth (Manual)")
    axs[i, 1].axis('off')

    # Prediction
    axs[i, 2].imshow(preds_np[i], cmap='gray', vmin=0, vmax=1)
    axs[i, 2].set_title("Prediction (U-Net)")
    axs[i, 2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# -------------------------------------------------------------------------
# DISPLAY AND SAVE ALL PREDICTIONS FOR THE VALIDATION SET
# -------------------------------------------------------------------------

OUTPUT_DIR = 'val_predictions_output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# 1. Load the best checkpoint.  map_location lets it load on whatever device
#    is in use now, even if it was written from a different one.
print("Caricamento del modello migliore...")
model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, 'best_model.pth'), map_location=device))
model.eval()

# Per-sample data kept in memory for the comparison figure at the end
plot_data = []

print(f"Salvataggio maschere in: {os.path.abspath(OUTPUT_DIR)}")

# Formats that would blur a binary mask with compression artefacts
LOSSY_EXTENSIONS = ('.jpg', '.jpeg')

with torch.no_grad():
    # Iterate over the dataset (not the loader) so each sample can be matched
    # with its original filename from val_dataset.images.
    for i, filename in enumerate(val_dataset.images):

        img_tensor, mask_tensor = val_dataset[i]

        # Add the batch dimension the network expects
        img_input = img_tensor.unsqueeze(0).to(device)

        # INFERENCE
        output = model(img_input)

        # Post-processing: sigmoid -> threshold 0.5 -> float, then drop the
        # batch and channel axes to get a 2-D array
        pred_prob = torch.sigmoid(output)
        pred_mask = (pred_prob > 0.5).float().cpu()[0, 0].numpy()

        # Ground truth for the plot, channel axis dropped
        gt_mask = mask_tensor[0].numpy()

        # Input image for the plot (tensor CHW -> numpy HWC)
        orig_img = img_tensor.permute(1, 2, 0).numpy()

        # --- SAVE TO DISK ---
        # Scale the 0.0/1.0 mask to 0/255 uint8 so it is visible as an image.
        pred_img_pil = Image.fromarray((pred_mask * 255).astype(np.uint8))
        # A mask written as JPEG would no longer be purely 0/255, so files
        # whose source image was a JPEG are written as PNG under the same stem.
        stem, ext = os.path.splitext(filename)
        save_name = f"{stem}.png" if ext.lower() in LOSSY_EXTENSIONS else filename
        save_path = os.path.join(OUTPUT_DIR, save_name)
        pred_img_pil.save(save_path)

        plot_data.append({
            'filename': filename,
            'orig': orig_img,
            'gt': gt_mask,
            'pred': pred_mask
        })

print(f"Generate {len(plot_data)} maschere.")

# --- PLOT ALL IMAGES ---
# Warning: with a large validation set this figure becomes very tall.
num_samples = len(plot_data)

# plt.subplots cannot create a grid with zero rows, so skip the figure when
# the validation set is empty.
if num_samples > 0:
    # squeeze=False keeps `axes` two-dimensional even for a single sample.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

    print("Generazione grafico comparativo...")

    for idx, data in enumerate(plot_data):
        # Column 1: original image
        axes[idx, 0].imshow(data['orig'])
        axes[idx, 0].set_title(f"Orig: {data['filename']}")
        axes[idx, 0].axis('off')

        # Column 2: manual mask (ground truth).  A fixed vmin/vmax keeps an
        # all-foreground mask white; with auto-scaling a constant array
        # renders black whatever its value.
        axes[idx, 1].imshow(data['gt'], cmap='gray', vmin=0, vmax=1)
        axes[idx, 1].set_title("Manual (Verità)")
        axes[idx, 1].axis('off')

        # Column 3: mask predicted by the U-Net
        axes[idx, 2].imshow(data['pred'], cmap='gray', vmin=0, vmax=1)
        axes[idx, 2].set_title("Predizione U-Net")
        axes[idx, 2].axis('off')

    plt.tight_layout()
    plt.show()