# 🏭 Advanced Segmentation Training (Galicia Toshiba Dataset)
**Objective**: Train a U-Net++ model to segment high-resolution images into 4 classes (background plus 3 types of industrial defect).

**Key Features**:
1.  **Smart Sampling**: Instead of random crops, we pre-scan the dataset to find where the defects are.
2.  **Class Oversampling**: Rare classes (e.g. Class 2) are sampled much more frequently.
3.  **Weighted Loss**: The loss function penalizes mistakes on rare classes much more heavily.
4.  **Full Image Inference**: A visualization block that runs sliding-window prediction over entire original images.

In [None]:
# 1. SETUP & IMPORTS
import os
import cv2
import torch
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm
import random
import matplotlib.pyplot as plt
from collections import defaultdict
from PIL import Image

# Pick the compute device: default to CPU, switch to CUDA when a GPU is visible.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'
print(f"🚀 Device: {DEVICE}")

In [None]:
# 2. CONFIGURATION
# ======================
# Dataset root; the train/val splits live in sub-folders of it.
BASE_DIR = "/home/dmore/code/TFM_David/CNN_Galicia_Toshiba/datasets/unet_dataset"
TRAIN_DIR, VAL_DIR = (os.path.join(BASE_DIR, split) for split in ("train", "val"))

# Mask label ids: 0=Safe, 1=DefectA, 2=DefectB, 3=DefectC
CLASSES = 4
TILE_SIZE = 1024  # side length, in pixels, of the square crops fed to the model
BATCH_SIZE = 4

# Training hyper-parameters
EPOCHS = 30
# Crops drawn per image per epoch; enlarges the (virtual) epoch to cover more area.
PATCHES_PER_IMAGE = 20
LEARNING_RATE = 1e-4

# Per-class loss weights, indexed by class id (higher = more focus on that class).
# Rough frequency-based estimate: Class 2 is 10x rarer than Class 1/3.
CLASS_WEIGHTS = [
    0.1,   # 0: Safe / background
    5.0,   # 1: DefectA
    50.0,  # 2: DefectB
    5.0,   # 3: DefectC
]

In [None]:
# 3. SMART PRE-SCANNING
# ======================
def scan_dataset_defects(root_dir):
    """
    Scan every mask under ``<root_dir>/masks`` and record where each class occurs.

    For each mask file and each non-zero label value present in it, exactly one
    entry is stored: the centroid of *all* pixels carrying that label in the image.
    This is a per-image, per-class centroid, not a per-blob one, so when a class
    appears in several separate regions the point can fall between them.

    Masks that cannot be read, or that are not single-channel, are skipped with a
    printed warning instead of aborting the scan.

    Args:
        root_dir: Split directory containing a ``masks`` sub-folder.

    Returns:
        A ``defaultdict(list)`` mapping ``class_id`` (plain ``int``) to a list of
        ``(mask_filename, center_y, center_x)`` tuples.
    """
    print(f"🕵️‍♂️ Scanning {root_dir} for defects (this happens once)...")
    mask_dir = os.path.join(root_dir, 'masks')
    defect_registry = defaultdict(list)

    # Case-insensitive extension match (the same rule the dataset class applies to
    # images); sorted so the registry order does not depend on the file system.
    files = sorted(f for f in os.listdir(mask_dir) if f.lower().endswith(('.png', '.jpg')))

    for f in tqdm(files):
        path = os.path.join(mask_dir, f)
        # Load mask (label values 0, 1, 2, 3)
        try:
            with Image.open(path) as img:
                mask = np.array(img)
        except Exception as err:
            # PIL raises several unrelated exception types for bad files. Unlike a
            # bare `except:`, this still lets Ctrl-C interrupt a long scan, and the
            # skipped file is reported instead of vanishing silently.
            print(f"⚠️ Skipping unreadable mask {f}: {err}")
            continue

        if mask.ndim != 2:
            # A multi-channel mask would break the (ys, xs) unpacking below.
            print(f"⚠️ Skipping mask {f}: expected a single-channel image, got shape {mask.shape}")
            continue

        for cls in np.unique(mask):
            if cls == 0:
                continue  # Skip background

            # `cls` comes from np.unique, so at least one pixel always matches.
            ys, xs = np.where(mask == cls)
            cy = int(np.mean(ys))
            cx = int(np.mean(xs))
            # Store a plain int key so lookups do not depend on numpy scalar hashing.
            defect_registry[int(cls)].append((f, cy, cx))

    print("✅ Scan complete.")
    for c, items in sorted(defect_registry.items()):
        print(f"  - Class {c}: Found {len(items)} instances.")

    return defect_registry

# Build the defect registry from the training split only (validation is never scanned).
train_defects = scan_dataset_defects(root_dir=TRAIN_DIR)

In [None]:
# 4. DATASET CLASS WITH SMART SAMPLING
# ======================
class IndustrialSmartDataset(Dataset):
    """
    Tile dataset that biases crops towards defect pixels.

    Every item is one ``tile_size`` x ``tile_size`` crop taken from a full image
    and its mask. The dataset has a *virtual* length of
    ``len(images) * patches_per_img`` so that several crops are drawn from each
    image per epoch.

    With probability ``DEFECT_CROP_PROB`` (and only when a non-empty
    ``defect_registry`` was supplied) the crop is centred on a randomly chosen
    defect pixel of the *current* image, preferring ``PRIORITY_CLASS`` pixels when
    the image contains any. Otherwise the crop position is uniformly random.

    Args:
        root_dir: Split directory containing ``images`` and ``masks`` sub-folders.
        defect_registry: Mapping ``class_id -> [(filename, cy, cx), ...]``. Only
            its emptiness is used, to decide whether defect-centred crops are enabled.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            and returning a dict with ``'image'`` and ``'mask'`` entries.
        patches_per_img: Number of virtual samples per image.
        tile_size: Crop side length; defaults to the notebook-level ``TILE_SIZE``.
        deterministic: If true, crop positions are derived from the sample index
            so the same index always yields the same crop. ``None`` (default)
            enables this exactly when ``defect_registry`` is ``None`` (the
            validation configuration), so the validation loss is comparable
            from one epoch to the next.
    """

    # Probability of centring a crop on a defect pixel instead of a random location.
    DEFECT_CROP_PROB = 0.6
    # Class id preferred as crop centre when present in the image (the rare class).
    PRIORITY_CLASS = 2

    def __init__(self, root_dir, defect_registry=None, transform=None, patches_per_img=10,
                 tile_size=None, deterministic=None):
        self.root_dir = root_dir
        self.transform = transform
        self.img_dir = os.path.join(root_dir, 'images')
        self.mask_dir = os.path.join(root_dir, 'masks')
        self.patches_per_img = patches_per_img
        self.tile_size = TILE_SIZE if tile_size is None else tile_size

        # Sorted so that index -> image is stable across runs and machines.
        self.images = sorted(
            f for f in os.listdir(self.img_dir) if f.lower().endswith(('.jpg', '.png'))
        )
        self.defect_registry = defect_registry

        # Flatten available defects into (class_id, entry) pairs.
        self.all_defects = []
        if self.defect_registry:
            for cls, items in self.defect_registry.items():
                self.all_defects.extend((cls, item) for item in items)

        if deterministic is None:
            deterministic = defect_registry is None
        self.deterministic = bool(deterministic)

    def __len__(self):
        # Virtual length: N patches are extracted per image per epoch.
        return len(self.images) * self.patches_per_img

    def _find_mask_path(self, img_name):
        """Return the mask path for ``img_name``: ``<stem>.png`` if it exists, else ``<stem>.jpg``."""
        stem = os.path.splitext(img_name)[0]
        png_path = os.path.join(self.mask_dir, stem + ".png")
        if os.path.exists(png_path):
            return png_path
        # NOTE(review): JPEG is lossy, so label ids read from a .jpg mask may be
        # altered by compression - confirm masks are stored losslessly.
        return os.path.join(self.mask_dir, stem + ".jpg")

    def _choose_origin(self, mask, use_defect_crop, rng=random):
        """
        Return the ``(y1, x1)`` top-left corner of the crop window.

        The origin is clamped to ``[0, dim - tile_size]``. When the image is
        smaller than the tile along an axis, the origin on that axis is 0 and the
        resulting crop is simply smaller (padding is left to the transform).
        """
        h, w = mask.shape[:2]
        size = self.tile_size
        # Clamp the upper bounds at 0: randint(0, negative) raises ValueError.
        max_y = max(0, h - size)
        max_x = max(0, w - size)

        if use_defect_crop:
            # Prefer the rare class; fall back to any defect pixel.
            ys, xs = np.where(mask == self.PRIORITY_CLASS)
            if len(ys) == 0:
                ys, xs = np.where(mask > 0)
            if len(ys) > 0:
                k = rng.randrange(len(ys))
                y1 = min(max(int(ys[k]) - size // 2, 0), max_y)
                x1 = min(max(int(xs[k]) - size // 2, 0), max_x)
                return y1, x1
            # No defects in this image: fall through to a random crop.

        return rng.randint(0, max_y), rng.randint(0, max_x)

    def __getitem__(self, idx):
        # Map virtual index to actual image
        img_name = self.images[idx % len(self.images)]

        # Per-index RNG in deterministic mode; the global `random` module otherwise.
        rng = random.Random(idx) if self.deterministic else random

        # DECISION: defect-centred crop or random (background) crop?
        use_defect_crop = (rng.random() < self.DEFECT_CROP_PROB) and (len(self.all_defects) > 0)

        img_path = os.path.join(self.img_dir, img_name)
        mask_path = self._find_mask_path(img_name)

        # The full image and mask are loaded and then cropped in memory.
        # cv2.imread returns None instead of raising, so check explicitly.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, 0)  # Grayscale
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        y1, x1 = self._choose_origin(mask, use_defect_crop, rng)
        size = self.tile_size
        image = image[y1:y1 + size, x1:x1 + size]
        mask = mask[y1:y1 + size, x1:x1 + size]

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        # Without a tensor-producing transform the mask is still a numpy array.
        if not torch.is_tensor(mask):
            mask = torch.as_tensor(np.ascontiguousarray(mask))

        # Ensure mask is long (int64) for torch loss
        return image, mask.long()

In [None]:
# 5. AUGMENTATION PIPELINE
# ======================
# ImageNet channel statistics; normalization is vital for pre-trained encoders.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _pad_to_tile():
    """Return a fresh transform that zero-pads image and mask up to TILE_SIZE x TILE_SIZE."""
    return A.PadIfNeeded(
        min_height=TILE_SIZE,
        min_width=TILE_SIZE,
        border_mode=cv2.BORDER_CONSTANT,
        value=0,
        mask_value=0,
    )


def _normalize_and_tensorize():
    """Return the fresh tail shared by both pipelines: normalize, then convert to tensors."""
    return [
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),
    ]


# Training: pad/crop to the tile size, geometric flips/rotations, then mild
# photometric jitter to simulate different lighting.
_train_steps = [
    _pad_to_tile(),
    A.RandomCrop(height=TILE_SIZE, width=TILE_SIZE, p=1.0),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.MultiplicativeNoise(multiplier=(0.9, 1.1), p=0.2),
    A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05, p=0.3),
]
_train_steps.extend(_normalize_and_tensorize())
# is_check_shapes=False disables the image/mask shape check.
train_transform = A.Compose(_train_steps, is_check_shapes=False)

# Validation: pad and centre-crop to the tile size, with no augmentation.
_val_steps = [
    _pad_to_tile(),
    A.CenterCrop(height=TILE_SIZE, width=TILE_SIZE, p=1.0),
]
_val_steps.extend(_normalize_and_tensorize())
val_transform = A.Compose(_val_steps, is_check_shapes=False)

In [None]:
# 6. MODEL SETUP & TRAINING LOOP
# ======================

# Datasets: training gets the defect registry and the augmenting pipeline;
# validation gets no registry and fewer patches per image.
train_dataset = IndustrialSmartDataset(
    TRAIN_DIR,
    defect_registry=train_defects,
    transform=train_transform,
    patches_per_img=PATCHES_PER_IMAGE,
)
val_dataset = IndustrialSmartDataset(
    VAL_DIR,
    defect_registry=None,
    transform=val_transform,
    patches_per_img=5,
)

# Loaders: both use the same batch size and worker count; only training
# shuffles and pins memory.
_loader_common = dict(batch_size=BATCH_SIZE, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, pin_memory=True, **_loader_common)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_common)

_n_train, _n_val = len(train_dataset), len(val_dataset)
print(f"Dataset Sizes (Virtual): Train={_n_train}, Val={_n_val}")

# Model: U-Net++ with an ImageNet-pretrained ResNet34 encoder, moved to DEVICE.
_model_kwargs = {
    "encoder_name": "resnet34",
    "encoder_weights": "imagenet",
    "in_channels": 3,
    "classes": CLASSES,
}
model = smp.UnetPlusPlus(**_model_kwargs)
model.to(DEVICE)

# Loss Function (Weighted): per-class weights make mistakes on rare classes cost more.
class_weights_t = torch.tensor(CLASS_WEIGHTS).float().to(DEVICE)
loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights_t)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# `verbose=` is deprecated since PyTorch 2.2 (and may be rejected by later releases),
# so it is not passed; learning-rate reductions are reported manually in the loop.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

# Fail fast with a clear message instead of a ZeroDivisionError after a full epoch.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise RuntimeError(
        f"Empty data loader (train batches={len(train_loader)}, "
        f"val batches={len(val_loader)}); check the dataset directories."
    )

# Training Loop
best_loss = float('inf')  # lowest average validation loss seen so far
train_logs = []           # average training loss per epoch
val_logs = []             # average validation loss per epoch

print("🔥 STARTING TRAINING...")

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0

    # Progress bar
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for images, masks in loop:
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)

        optimizer.zero_grad()
        logits = model(images)  # Output: [B, 4, H, W]

        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    # Validation: no gradient tracking, model in eval mode.
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)
            logits = model(images)
            loss = loss_fn(logits, masks)
            val_loss += loss.item()

    # Averages are per batch (each batch loss is already a mean).
    avg_train_loss = epoch_loss / len(train_loader)
    avg_val_loss = val_loss / len(val_loader)

    train_logs.append(avg_train_loss)
    val_logs.append(avg_val_loss)

    # Step the plateau scheduler on the validation loss and report any LR change.
    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val_loss)
    lr_after = optimizer.param_groups[0]['lr']
    if lr_after != lr_before:
        print(f"   🔧 Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

    print(f"   📉 Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Checkpoint only when the validation loss improves.
    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        torch.save(model.state_dict(), "best_model_smart_galicia.pth")
        print("   💾 Model Saved!")
        
# Plot the per-epoch training and validation loss curves on a single axes.
_fig, _ax = plt.subplots(figsize=(10, 5))
for _series, _label in ((train_logs, 'Train'), (val_logs, 'Validation')):
    _ax.plot(_series, label=_label)
_ax.set_title("Training Loss")
_ax.legend()
plt.show()

In [None]:
# 7. INFERENCE & VISUALIZATION BLOCK
# ======================
def predict_full_image(model, img_path, tile_size=1024):
    """
    Predict a class mask for a whole image by tiling it without overlap.

    The image is zero-padded on the bottom/right to a multiple of ``tile_size``,
    each tile is normalized with the same ImageNet statistics used by the
    training pipeline, run through ``model``, and the per-pixel argmax is written
    into the output mask. The padding is cropped away before returning.

    The model is switched to eval mode and left in that mode.

    Args:
        model: Segmentation network returning logits shaped ``[B, C, H, W]``.
        img_path: Path of the image to read with OpenCV.
        tile_size: Side length of the square tiles.

    Returns:
        ``(original_img, final_mask)``: the RGB image as read, and a ``uint8``
        array with the same height and width holding the predicted class ids.

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    # cv2.imread returns None instead of raising on a missing/unreadable file.
    original_img = cv2.imread(img_path)
    if original_img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
    h, w = original_img.shape[:2]

    # Pad to a multiple of tile_size to avoid border issues
    pad_h = (tile_size - (h % tile_size)) % tile_size
    pad_w = (tile_size - (w % tile_size)) % tile_size

    padded_img = cv2.copyMakeBorder(original_img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=0)
    final_mask = np.zeros((padded_img.shape[0], padded_img.shape[1]), dtype=np.uint8)

    # Build the (stateless) transforms once instead of once per tile.
    normalize = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    to_tensor = ToTensorV2()

    # Run on whatever device the model's weights live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Sliding Window
    model.eval()
    with torch.no_grad():
        for y in range(0, padded_img.shape[0], tile_size):
            for x in range(0, padded_img.shape[1], tile_size):
                tile = padded_img[y:y + tile_size, x:x + tile_size]

                tile_tensor = normalize(image=tile)['image']
                tile_tensor = to_tensor(image=tile_tensor)['image']
                tile_tensor = tile_tensor.unsqueeze(0).to(device)

                logits = model(tile_tensor)
                # Drop only the batch axis; a bare squeeze() would also drop any
                # spatial axis of length 1.
                preds = torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

                final_mask[y:y + tile_size, x:x + tile_size] = preds

    # Crop back to the original size
    final_mask = final_mask[:h, :w]
    return original_img, final_mask

# --- TEST ON A FEW IMAGES ---
# Sanity check on TRAIN images: one sample per defect class when the scan found any.
train_img_dir = os.path.join(TRAIN_DIR, 'images')
train_mask_dir = os.path.join(TRAIN_DIR, 'masks')

# The scan registry stores MASK filenames, whose extension may differ from the
# image's (e.g. foo.png mask for foo.jpg image), so images are looked up by stem.
image_by_stem = {
    os.path.splitext(f)[0]: f
    for f in sorted(os.listdir(train_img_dir))
    if f.lower().endswith(('.jpg', '.png'))
}

test_samples = []
for cls in [1, 2, 3]:
    entries = train_defects.get(cls)
    if entries:
        stem = os.path.splitext(entries[0][0])[0]
        if stem in image_by_stem:
            test_samples.append(image_by_stem[stem])

# If nothing was found, just take the first few images
if not test_samples:
    test_samples = list(image_by_stem.values())[:3]

# De-duplicate while keeping a stable order (iterating a set is unordered).
test_samples = list(dict.fromkeys(test_samples))

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load("best_model_smart_galicia.pth", map_location=DEVICE))

# Overlay colours per predicted class id (RGB): 1=Red, 2=Green, 3=Blue
class_colors = {1: (255, 0, 0), 2: (0, 255, 0), 3: (0, 0, 255)}
overlay_alpha = 153  # 0.6 opacity on a 0-255 scale

for sample_name in test_samples:
    print(f"🔍 Predicting {sample_name}...")
    full_path = os.path.join(train_img_dir, sample_name)
    if not os.path.exists(full_path):
        continue

    img, pred_mask = predict_full_image(model, full_path)

    # Load Ground Truth for comparison: <stem>.png first, then <stem>.jpg
    # (the same fallback order the dataset class uses).
    stem = os.path.splitext(sample_name)[0]
    gt_mask = None
    for ext in ('.png', '.jpg'):
        candidate = os.path.join(train_mask_dir, stem + ext)
        if os.path.exists(candidate):
            gt_mask = cv2.imread(candidate, 0)
            break

    # VISUALIZE
    plt.figure(figsize=(18, 6))

    plt.subplot(1, 4, 1)
    plt.imshow(img)
    plt.title("Original Image")
    plt.axis('off')

    plt.subplot(1, 4, 2)
    if gt_mask is not None:
        plt.imshow(gt_mask, cmap='viridis', vmin=0, vmax=3)
        plt.title("Ground Truth")
    else:
        # cv2.imread returns None for a missing/unreadable file; imshow(None) would crash.
        plt.title("Ground Truth (not found)")
    plt.axis('off')

    plt.subplot(1, 4, 3)
    plt.imshow(pred_mask, cmap='viridis', vmin=0, vmax=3)
    plt.title("Prediction")
    plt.axis('off')

    plt.subplot(1, 4, 4)
    # Overlay Prediction on Image.
    # RGBA overlay: background pixels stay fully transparent so the image is not
    # darkened; only predicted defect pixels are tinted.
    overlay = np.zeros(pred_mask.shape + (4,), dtype=np.uint8)
    for cls_id, rgb in class_colors.items():
        overlay[pred_mask == cls_id] = rgb + (overlay_alpha,)

    plt.imshow(img)
    plt.imshow(overlay)
    plt.title("Overlay (Pred)")
    plt.axis('off')

    plt.show()