# Steatosis Segmentation using U-Net
This notebook trains a U-Net model to segment liver steatosis using the provided training and validation datasets.
It implements the full pipeline: data loading, model definition, training loop, and evaluation.

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Training hyperparameters
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
NUM_EPOCHS = 20
INPUT_SIZE = (256, 256)  # target size every image and mask is resized to

# Checkpoints are written into this directory; it is created up front if missing.
CHECKPOINT_DIR = 'checkpoints_steatosis'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Dataset roots (adjust these if your folder layout differs).
TRAIN_DIR = 'train'
VAL_DIR = 'val'

In [None]:
class SteatosisDataset(Dataset):
    """Paired image/mask dataset for binary steatosis segmentation.

    Images are read from ``<root_dir>/image`` and masks from
    ``<root_dir>/manual_py``. Files are paired by position after sorting each
    directory's filenames, so image and mask names must sort into the same order.
    """

    def __init__(self, root_dir, transform=None, input_size=None):
        """
        Args:
            root_dir (string): Directory with 'image' and 'manual_py' subdirs.
            transform (callable, optional): Joint transform applied to a sample
                after resizing, tensor conversion and mask binarization. It is
                called as ``transform(image, mask)`` and must return an
                ``(image, mask)`` pair. It receives both tensors so that any
                geometric augmentation stays aligned between image and mask.
            input_size (tuple, optional): Size passed to ``transforms.Resize``.
                Defaults to the notebook-level ``INPUT_SIZE``.
        """
        self.root_dir = root_dir
        self.image_dir = os.path.join(root_dir, 'image')
        # Training masks live in 'manual_py' (not 'manual').
        self.mask_dir = os.path.join(root_dir, 'manual_py')
        self.transform = transform
        self.input_size = INPUT_SIZE if input_size is None else input_size

        # Build the per-sample preprocessing once instead of on every __getitem__ call.
        # Bilinear interpolation suits photographs.
        self._resize_image = transforms.Resize(
            self.input_size, interpolation=transforms.InterpolationMode.BILINEAR)
        # NEAREST keeps mask labels crisp; interpolating would invent in-between values.
        self._resize_mask = transforms.Resize(
            self.input_size, interpolation=transforms.InterpolationMode.NEAREST)
        self._to_tensor = transforms.ToTensor()

        # Only files with common image extensions are considered.
        valid_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tif')
        self.images = sorted(f for f in os.listdir(self.image_dir) if f.lower().endswith(valid_extensions))
        self.masks = sorted(f for f in os.listdir(self.mask_dir) if f.lower().endswith(valid_extensions))

        if len(self.images) != len(self.masks):
            print(f"Warning: Number of images ({len(self.images)}) and masks ({len(self.masks)}) in {root_dir} do not match!")
            print(f"Warning: only the first {len(self)} sorted image/mask pairs will be used.")

    def __len__(self):
        # Only positions that have both an image and a mask are valid samples.
        # Otherwise indexing past the shorter list would raise IndexError mid-epoch.
        return min(len(self.images), len(self.masks))

    def __getitem__(self, idx):
        """Return ``(image, mask)`` float tensors.

        ``image`` is 3xHxW with values in [0, 1]. ``mask`` is 1xHxW with
        values in {0, 1}.
        """
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.masks[idx])

        # Context managers release the file handles as soon as the pixels are decoded.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.convert("L")  # single-channel mask

        image = self._to_tensor(self._resize_image(image))
        mask = self._to_tensor(self._resize_mask(mask))

        # ToTensor scales 8-bit values by 1/255. A mask stored as {0, 1} therefore
        # becomes {0, 1/255}, and one stored as {0, 255} becomes {0, 1}.
        # Treating any non-zero value as foreground handles both encodings.
        mask = (mask > 0).float()

        if self.transform is not None:
            image, mask = self.transform(image, mask)

        return image, mask

In [None]:
# One Dataset per split, rooted at the directories configured earlier.
train_dataset = SteatosisDataset(TRAIN_DIR)
val_dataset = SteatosisDataset(VAL_DIR)

# Only the training split is shuffled. Validation is iterated in a fixed order.
# Both loaders use two worker processes.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=2, shuffle=False)

print(f"Training set: {len(train_dataset)} images")
print(f"Validation set: {len(val_dataset)} images")

# Sanity check: inspect the first training sample and show image and mask side by side.
temp_img, temp_mask = train_dataset[0]
print(f"Image tensor shape: {temp_img.shape}")
print(f"Mask tensor shape: {temp_mask.shape}")
print(f"Mask unique values: {torch.unique(temp_mask)}")

plt.figure(figsize=(10, 5))
sample_panels = (
    # (title, data in HxW[xC] layout for imshow, colormap)
    ("Sample Image", temp_img.permute(1, 2, 0), None),
    ("Sample Mask", temp_mask.squeeze(), 'gray'),
)
for panel_idx, (panel_title, panel_data, panel_cmap) in enumerate(sample_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    plt.imshow(panel_data, cmap=panel_cmap)
    plt.title(panel_title)
    plt.axis('off')
plt.show()

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv -> BatchNorm -> ReLU) stages.

    ``padding=1`` keeps the spatial size unchanged. Only the channel count goes
    from ``in_channels`` to ``out_channels``, and that happens in the first stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in->out channels; the second stage maps out->out.
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    """Configurable-depth U-Net producing raw per-pixel logits.

    The encoder has ``depth`` DoubleConv stages with widths
    ``init_filters * 2**k`` (k = 0..depth-1), each followed by 2x max-pooling.
    The bottleneck has width ``init_filters * 2**depth``. The decoder mirrors
    the encoder, concatenating each upsampled map with the matching skip
    connection. A final 1x1 conv maps ``init_filters`` channels to
    ``out_channels``.

    Args:
        in_channels: Number of input image channels.
        out_channels: Number of output logit channels.
        init_filters: Width of the first encoder stage.
        depth: Number of down/up-sampling stages.
        bilinear: If True, upsample with bilinear interpolation followed by a
            1x1 conv that halves the channels. Otherwise use a stride-2
            transposed conv.
    """

    def __init__(self,
                 in_channels=3,
                 out_channels=1,
                 init_filters=64,
                 depth=4,
                 bilinear=True):
        super().__init__()
        self.depth = depth
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.pool = nn.MaxPool2d(2)

        # Encoder: widths double at every level.
        level_widths = [init_filters * 2 ** level for level in range(depth)]
        channels = in_channels
        for width in level_widths:
            self.down_layers.append(DoubleConv(channels, width))
            channels = width

        # Bottleneck doubles the deepest encoder width once more.
        self.bottleneck = DoubleConv(channels, init_filters * 2 ** depth)

        # Decoder: walk the encoder widths deepest-first. Each stage halves the
        # channels of its input, then fuses with the skip of the same width.
        for width in reversed(level_widths):
            incoming = width * 2
            if bilinear:
                upsampler = nn.Sequential(
                    nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                    nn.Conv2d(incoming, width, kernel_size=1),
                )
            else:
                upsampler = nn.ConvTranspose2d(incoming, width, kernel_size=2, stride=2)
            fuse = DoubleConv(incoming, width)
            self.up_layers.append(nn.ModuleDict({'up': upsampler, 'conv': fuse}))

        # 1x1 projection from the shallowest decoder width to the output channels.
        self.out_conv = nn.Conv2d(init_filters, out_channels, kernel_size=1)

    def forward(self, x):
        skips = []
        for encoder_stage in self.down_layers:
            x = encoder_stage(x)
            skips.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Pair decoder stages with skip connections deepest-first.
        for stage, skip in zip(self.up_layers, reversed(skips)):
            upsampled = stage['up'](x)
            if upsampled.size() != skip.size():
                # Odd input sizes lose a pixel when pooled. Resize the upsampled
                # map to the skip's spatial size so the two can be concatenated.
                upsampled = F.interpolate(upsampled, size=skip.shape[2:])
            x = stage['conv'](torch.cat([skip, upsampled], dim=1))

        return self.out_conv(x)

In [None]:
# Metric: Dice Coefficient
def dice_coefficient(pred, target, smooth=1e-5):
    """Hard Dice score between thresholded predictions and a binary target.

    Args:
        pred: Raw logits. They are passed through a sigmoid and thresholded at 0.5.
        target: Ground-truth mask with the same shape as ``pred``, expected to
            contain 0/1 values.
        smooth: Added to numerator and denominator. It avoids division by zero
            and gives a score of 1 when both prediction and target are empty.

    Returns:
        Scalar tensor. All elements of the batch are pooled together; the score
        is not averaged per image.
    """
    binary_pred = (torch.sigmoid(pred) > 0.5).float()
    overlap = torch.sum(binary_pred * target)
    total = torch.sum(binary_pred) + torch.sum(target)
    return (2. * overlap + smooth) / (total + smooth)

# Build the U-Net with 32 base filters instead of the class default of 64,
# and move it to the selected device.
model = UNet(
    in_channels=3,
    out_channels=1,
    init_filters=32,
    depth=4,
).to(device)

# BCEWithLogitsLoss applies the sigmoid internally, which matches the raw
# logits returned by the model.
criterion = nn.BCEWithLogitsLoss()

# Adam over all model parameters, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

print("Model initialized.")

In [None]:
train_losses = []
val_losses = []
val_dices = []

# Dice lies in [0, 1], so starting below that range guarantees that the first
# epoch writes best_model.pth. With 0.0, a run whose validation Dice never rose
# above 0 saved no "best" checkpoint, and loading best_model.pth later failed.
best_dice = -1.0

print("Starting training...")

for epoch in range(NUM_EPOCHS):
    # ---- Training phase ----
    model.train()
    running_loss = 0.0

    for images, masks in tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} - Train', leave=False):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch value is a per-sample mean even
        # when the last batch is smaller.
        running_loss += loss.item() * images.size(0)

    epoch_train_loss = running_loss / len(train_dataset)
    train_losses.append(epoch_train_loss)

    # ---- Validation phase ----
    model.eval()
    running_val_loss = 0.0
    running_dice = 0.0

    with torch.no_grad():
        for images, masks in tqdm(val_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} - Val', leave=False):
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)
            running_val_loss += loss.item() * images.size(0)

            # dice_coefficient pools all pixels of the batch. Weighting by batch
            # size makes the epoch value a sample-weighted mean of per-batch
            # Dice, not a mean of per-image Dice.
            running_dice += dice_coefficient(outputs, masks).item() * images.size(0)

    epoch_val_loss = running_val_loss / len(val_dataset)
    epoch_dice = running_dice / len(val_dataset)

    val_losses.append(epoch_val_loss)
    val_dices.append(epoch_dice)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Dice: {epoch_dice:.4f}")

    # Save Best Model
    if epoch_dice > best_dice:
        best_dice = epoch_dice
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'best_model.pth'))
        print(f"--> Best model saved with Dice: {best_dice:.4f}")

    # Save Checkpoint periodically
    if (epoch + 1) % 5 == 0:
        torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f'model_epoch_{epoch+1}.pth'))

# Save training parameters for future inference. Architecture values are read
# back from the trained model so they cannot drift from the UNet(...) call that
# built it.
# The first Conv2d registered in the model consumes the input image.
first_conv = next(m for m in model.modules() if isinstance(m, nn.Conv2d))
params = {
    "input_size": INPUT_SIZE,
    "in_channels": first_conv.in_channels,
    "out_channels": model.out_conv.out_channels,
    "init_filters": model.out_conv.in_channels,  # out_conv maps init_filters -> out_channels
    "depth": model.depth,
    "best_dice": best_dice
}
with open(os.path.join(CHECKPOINT_DIR, 'training_params.json'), 'w') as f:
    json.dump(params, f, indent=4)

print("Training completed.")

In [None]:
# Plot Metrics
plt.figure(figsize=(15, 5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss', marker='o')
plt.plot(val_losses, label='Validation Loss', marker='o')
plt.title('Loss over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Loss (BCE)')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(val_dices, label='Validation Dice', color='green', marker='o')
plt.title('Dice Score over Epochs')
plt.xlabel('Epochs')
plt.ylabel('Dice Coefficient')
plt.legend()
plt.grid(True)
plt.show()

# Visualize Predictions on Validation Set using the best checkpoint.
# map_location keeps loading working even if the checkpoint was written on a
# different device than the one currently selected.
model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, 'best_model.pth'), map_location=device))
model.eval()

images, masks = next(iter(val_loader))
images = images.to(device)
with torch.no_grad():
    outputs = model(images)
    preds = torch.sigmoid(outputs) > 0.5

# Convert to numpy for plotting. squeeze(1) drops only the singleton channel
# axis. A bare squeeze() would also drop the batch axis when the batch holds a
# single image, which would make masks_np[i] index image rows.
images_np = images.cpu().permute(0, 2, 3, 1).numpy()
masks_np = masks.cpu().squeeze(1).numpy()
preds_np = preds.cpu().squeeze(1).float().numpy()

# Bound by the actual batch size: the first validation batch can be smaller
# than BATCH_SIZE (or smaller than 3) when the validation set is small.
n_plot = min(3, images.size(0))
plt.figure(figsize=(15, 5 * n_plot))

for i in range(n_plot):
    plt.subplot(n_plot, 3, i*3 + 1)
    plt.imshow(images_np[i])
    plt.title(f"Original Image {i+1}")
    plt.axis('off')

    plt.subplot(n_plot, 3, i*3 + 2)
    plt.imshow(masks_np[i], cmap='gray')
    plt.title(f"Ground Truth {i+1}")
    plt.axis('off')

    plt.subplot(n_plot, 3, i*3 + 3)
    plt.imshow(preds_np[i], cmap='gray')
    plt.title(f"Prediction {i+1}")
    plt.axis('off')

plt.tight_layout()
plt.show()