In [1]:
import os
import cv2
import numpy as np
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# --------------------------------------------------------------------------------
# 1. Definitions: Directories and Parameters
# --------------------------------------------------------------------------------

# Raw data lives below DATA_ROOT, one sub-folder per inspection domain.
DATA_ROOT = "data"
DOMAINS = ["Dichtflächen", "Bonding", "Wirecheck"]

# Expected layout below DATA_ROOT (the IO folder is optional per domain):
#   data/Dichtflächen/IO/
#   data/Dichtflächen/NIO/
#   data/Bonding/IO/
#   data/Bonding/NIO/
#   data/Wirecheck/NIO/      (no IO images available)
# The extraction step further down additionally reads "<domain>/Masks/<name>_mask.png"
# for every NIO image.

# Patch extraction settings
PATCH_ROOT = "patches"           # output folder for the extracted patches
PATCH_SIZE = 256                 # patches are PATCH_SIZE x PATCH_SIZE pixels
STRIDE = PATCH_SIZE // 2         # 128 px step -> 50% overlap between neighbours

# Training-time augmentation. The steps run in list order; Normalize and
# ToTensorV2 must stay last so the augmentations operate on raw pixel data.
train_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=30, p=0.5),
        A.RandomBrightnessContrast(
            brightness_limit=0.2,
            contrast_limit=0.2,
            p=0.5,
        ),
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.2),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
)

In [None]:
# --------------------------------------------------------------------------------
# 2. Helper Function: Extract patches from one image and its mask
# --------------------------------------------------------------------------------

def _patch_origins(length, patch_size, stride):
    """Return the start offsets of all patches along one image axis.

    Offsets advance by `stride`. If the regular grid does not end exactly at the
    border, one extra offset aligned to the border is appended so that the last
    pixels of the axis are covered as well. Returns an empty list when the axis
    is shorter than `patch_size`.
    """
    last_origin = length - patch_size
    if last_origin < 0:
        return []
    origins = list(range(0, last_origin + 1, stride))
    if origins[-1] != last_origin:
        # The grid stopped short of the border: add a border-aligned patch.
        origins.append(last_origin)
    return origins


def extract_patches(image, mask, patch_size=PATCH_SIZE, stride=STRIDE):
    """
    Splits a (H, W, C) image and its (H, W) mask into overlapping patches.

    Patches are taken on a regular grid with step `stride`. When the image size
    is not an exact fit for the grid, an additional row/column of patches aligned
    to the bottom/right border is added so that no image region is left out.
    Images smaller than `patch_size` in either dimension yield no patches.

    Args:
        image: array whose first two axes are height and width.
        mask: array indexed the same way as `image` along its first two axes.
        patch_size: side length of the square patches, in pixels.
        stride: step between neighbouring patches, in pixels.

    Returns:
        A list of tuples (img_patch, mask_patch, y, x) in row-major order.
        y and x are the top-left coordinates of the patch in the original image.
        The patches are slices of the inputs, not copies.

    Raises:
        ValueError: if `patch_size` or `stride` is not positive.
    """
    if patch_size <= 0 or stride <= 0:
        raise ValueError(
            f"patch_size and stride must be positive, got {patch_size} and {stride}"
        )

    h, w = image.shape[:2]
    row_origins = _patch_origins(h, patch_size, stride)
    col_origins = _patch_origins(w, patch_size, stride)

    return [
        (
            image[y:y + patch_size, x:x + patch_size],
            mask[y:y + patch_size, x:x + patch_size],
            y,
            x,
        )
        for y in row_origins
        for x in col_origins
    ]

# --------------------------------------------------------------------------------
# 3. Step 1: Create patch directories and extract all patches
# --------------------------------------------------------------------------------

# Ensure patch output directories exist
for split in ["train", "val", "test"]:
    for subfolder in ["images", "masks"]:
        os.makedirs(os.path.join(PATCH_ROOT, split, subfolder), exist_ok=True)

# Lower-case file extensions that are treated as input images
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".tif", ".tiff")


def _list_image_files(folder):
    """Return the image file names found in `folder`, sorted by name.

    os.listdir() returns entries in arbitrary order; sorting keeps the order of
    the extracted records identical between runs and machines.
    """
    return sorted(
        fname for fname in os.listdir(folder)
        if fname.lower().endswith(IMAGE_EXTENSIONS)
    )


def _read_rgb(img_path):
    """Read `img_path` as an RGB array; return None if OpenCV cannot read it."""
    bgr = cv2.imread(img_path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _patch_records(image, mask, domain, fname):
    """Cut `image`/`mask` into patches and return one record dict per patch.

    A patch is labelled NIO (1) if any pixel of its mask patch is > 0,
    otherwise IO (0).
    """
    return [
        {
            "img_patch": img_patch,
            "mask_patch": mask_patch,
            "label": 1 if np.any(mask_patch > 0) else 0,
            "domain": domain,
            "orig_image": fname,
            "y": y,
            "x": x,
        }
        for img_patch, mask_patch, y, x in extract_patches(image, mask)
    ]


# Every record keeps the patch pixels in memory together with its metadata.
# Keys: ['img_patch', 'mask_patch', 'label', 'domain', 'orig_image', 'y', 'x']
# Train/val/test splits are assigned later, per original image.
records = []

# Iterate over domains
for domain in DOMAINS:
    domain_folder = os.path.join(DATA_ROOT, domain)
    nio_folder = os.path.join(domain_folder, "NIO")

    # IO folder may not exist for Wirecheck
    io_folder = os.path.join(domain_folder, "IO")
    if not os.path.isdir(io_folder):
        io_folder = None

    # Process NIO images first. The mask of "<name>.<ext>" is expected at
    # "<domain>/Masks/<name>_mask.png"; adjust here if the layout differs.
    for fname in _list_image_files(nio_folder):
        img_path = os.path.join(nio_folder, fname)
        mask_name = os.path.splitext(fname)[0] + "_mask.png"
        mask_path = os.path.join(domain_folder, "Masks", mask_name)
        if not os.path.exists(mask_path):
            print(f"Mask not found for {img_path}, skipping.")
            continue

        image = _read_rgb(img_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if image is None or mask is None:
            print(f"Could not read image or mask for {img_path}, skipping.")
            continue
        if mask.shape[:2] != image.shape[:2]:
            # Slicing a mask of a different size would silently mislabel patches.
            print(
                f"Mask size {mask.shape[:2]} does not match image size "
                f"{image.shape[:2]} for {img_path}, skipping."
            )
            continue

        records.extend(_patch_records(image, mask, domain, fname))

    # Process IO images if folder exists (skip for Wirecheck)
    if io_folder:
        for fname in _list_image_files(io_folder):
            img_path = os.path.join(io_folder, fname)
            image = _read_rgb(img_path)
            if image is None:
                print(f"Could not read image {img_path}, skipping.")
                continue

            # IO images have no defects: use an all-zero mask, which makes
            # every patch come out with label 0.
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
            records.extend(_patch_records(image, mask, domain, fname))

# --------------------------------------------------------------------------------
# 4. Step 2: Special handling for Wirecheck (pseudo-IO patches)
# --------------------------------------------------------------------------------

# Identify Wirecheck records in our list for NIO-only domain
wire_records = [r for r in records if r["domain"] == "Wirecheck"]

# Number of pseudo-IO patches to sample per original Wirecheck image
PSEUDO_IO_PER_IMAGE = 3
# Maximum number of random draws per image before giving up
PSEUDO_IO_MAX_ATTEMPTS = 100
# Fixed seed so that the sampled patches are the same on every run
PSEUDO_IO_SEED = 42
pseudo_io_rng = np.random.RandomState(PSEUDO_IO_SEED)

new_pseudo_records = []

# Top-left corners that already exist per image. A pseudo-IO patch at one of
# these positions would duplicate an existing record (and its output file name).
taken_coords = {}
for record in wire_records:
    taken_coords.setdefault(record["orig_image"], set()).add(
        (int(record["y"]), int(record["x"]))
    )

# Process each Wirecheck image once, in sorted order so that the seeded random
# stream is consumed identically regardless of the order of `records`.
processed_images = set()
for fname in sorted(taken_coords):
    processed_images.add(fname)

    # Load full image and mask again
    img_path = os.path.join(DATA_ROOT, "Wirecheck", "NIO", fname)
    mask_name = os.path.splitext(fname)[0] + "_mask.png"
    mask_path = os.path.join(DATA_ROOT, "Wirecheck", "Masks", mask_name)
    image_bgr = cv2.imread(img_path)
    mask_full = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if image_bgr is None or mask_full is None:
        # cv2.imread returns None on failure instead of raising.
        print(f"Could not read image or mask for {img_path}, skipping.")
        continue
    image_full = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    h, w = mask_full.shape
    if h < PATCH_SIZE or w < PATCH_SIZE:
        # No full patch fits; randint below would be called with an empty range.
        continue

    # Randomly sample positions whose mask patch is entirely background (== 0)
    found = 0
    for _ in range(PSEUDO_IO_MAX_ATTEMPTS):
        if found >= PSEUDO_IO_PER_IMAGE:
            break
        ry = int(pseudo_io_rng.randint(0, h - PATCH_SIZE + 1))
        rx = int(pseudo_io_rng.randint(0, w - PATCH_SIZE + 1))
        if (ry, rx) in taken_coords[fname]:
            continue
        patch_mask = mask_full[ry:ry + PATCH_SIZE, rx:rx + PATCH_SIZE]
        if not np.all(patch_mask == 0):
            continue

        # This is a valid background patch
        taken_coords[fname].add((ry, rx))
        new_pseudo_records.append({
            "img_patch": image_full[ry:ry + PATCH_SIZE, rx:rx + PATCH_SIZE],
            "mask_patch": np.zeros((PATCH_SIZE, PATCH_SIZE), dtype=np.uint8),
            "label": 0,  # IO
            "domain": "Wirecheck",
            "orig_image": fname,
            "y": ry,
            "x": rx
        })
        found += 1

# Append pseudo-IO records to main records list
records.extend(new_pseudo_records)

# --------------------------------------------------------------------------------
# 5. Step 3: Build DataFrame and assign train/val/test splits at image level
# --------------------------------------------------------------------------------

# Create a DataFrame for easier manipulation
df = pd.DataFrame(records)
if df.empty:
    # Without this check the first column access below fails with a bare KeyError.
    raise RuntimeError(
        "No patches were extracted; check DATA_ROOT, the domain folders and the mask files."
    )

# We need to split per domain, ensuring that all patches from a given orig_image go to same split
df["split"] = ""  # placeholder

split_ratios = {"train": 0.7, "val": 0.15, "test": 0.15}

# Fixed seed so that re-running the notebook reproduces the same split
SPLIT_SEED = 42
split_rng = np.random.RandomState(SPLIT_SEED)

for domain in DOMAINS:
    in_domain = df["domain"] == domain

    # Unique image names of this domain, sorted first so that the shuffle result
    # depends only on the seed and not on the order in which records were collected.
    unique_images = np.array(sorted(df.loc[in_domain, "orig_image"].unique()), dtype=object)
    split_rng.shuffle(unique_images)

    # int() truncates, so any remainder ends up in the test split; domains with
    # very few images can therefore get an empty train and/or val split.
    n_total = len(unique_images)
    n_train = int(split_ratios["train"] * n_total)
    n_val = int(split_ratios["val"] * n_total)

    train_imgs = unique_images[:n_train]
    val_imgs = unique_images[n_train:n_train + n_val]
    test_imgs = unique_images[n_train + n_val:]

    # Assign splits: every image of the domain appears in exactly one of the lists
    split_of_image = {}
    split_of_image.update(dict.fromkeys(train_imgs, "train"))
    split_of_image.update(dict.fromkeys(val_imgs, "val"))
    split_of_image.update(dict.fromkeys(test_imgs, "test"))
    df.loc[in_domain, "split"] = df.loc[in_domain, "orig_image"].map(split_of_image)

# --------------------------------------------------------------------------------
# 6. Step 4: Save patches as image files and update paths in DataFrame
# --------------------------------------------------------------------------------

def save_patches(df_split, split_name, target_df=None):
    """
    Save image patches and mask patches to disk and update their paths in the DataFrame.

    For every row of `df_split` the image patch and the mask patch are written as
    PNG files to "<PATCH_ROOT>/<split_name>/images" and ".../masks" under the same
    file name, built from domain, original image name, patch position and label.

    Args:
        df_split: rows to save; must provide the columns 'domain', 'orig_image',
            'y', 'x', 'label', 'img_patch' (RGB array) and 'mask_patch'.
        split_name: name of the split sub-folder ("train", "val" or "test").
        target_df: DataFrame that receives the 'patch_img_path' and
            'patch_mask_path' values, addressed by the index labels of
            `df_split`. Defaults to the module-level `df`.

    Raises:
        OSError: if OpenCV reports that a patch could not be written.
    """
    target = df if target_df is None else target_df

    split_dir = os.path.join(PATCH_ROOT, split_name)
    img_dir = os.path.join(split_dir, "images")
    mask_dir = os.path.join(split_dir, "masks")
    # Make the function usable on its own, even if the folders were not created before.
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(mask_dir, exist_ok=True)

    for idx, row in df_split.iterrows():
        domain = row["domain"]
        orig_img_name = os.path.splitext(row["orig_image"])[0]
        y, x = row["y"], row["x"]
        label = row["label"]

        # Construct filenames
        patch_basename = f"{domain}_{orig_img_name}_y{y}_x{x}_lbl{label}.png"
        img_output_path = os.path.join(img_dir, patch_basename)
        mask_output_path = os.path.join(mask_dir, patch_basename)

        # Save. cv2.imwrite returns False on failure instead of raising, so the
        # result has to be checked to avoid recording paths of files that do not exist.
        # NOTE(review): on some platforms OpenCV reportedly cannot handle non-ASCII
        # paths (e.g. "Dichtflächen") - confirm if writes fail here.
        bgr_patch = cv2.cvtColor(row["img_patch"], cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(img_output_path, bgr_patch):
            raise OSError(f"Could not write image patch: {img_output_path}")
        if not cv2.imwrite(mask_output_path, row["mask_patch"]):
            raise OSError(f"Could not write mask patch: {mask_output_path}")

        # Update DataFrame paths
        target.at[idx, "patch_img_path"] = img_output_path
        target.at[idx, "patch_mask_path"] = mask_output_path

# Apply saving for each split
# Write the patches of every split to its own folder; save_patches records the
# resulting file paths in `df`.
for split_name in ("train", "val", "test"):
    belongs_to_split = df["split"] == split_name
    df_split = df[belongs_to_split].copy()
    save_patches(df_split, split_name)

# The manifest only needs file paths and metadata, not the in-memory pixel data.
manifest_columns = ["patch_img_path", "patch_mask_path", "label", "domain", "split"]
df_final = df[manifest_columns].copy()

# Persist the manifest as CSV for record-keeping
df_final.to_csv("patches_dataset.csv", index=False)

# --------------------------------------------------------------------------------
# 7. Step 5: Define PyTorch Dataset and DataLoader with WeightedRandomSampler
# --------------------------------------------------------------------------------

class PatchDataset(Dataset):
    """PyTorch Dataset that loads image/mask patches listed in a DataFrame.

    Each item is a tuple (image, mask, label). Without a transform, image and
    mask are the arrays returned by OpenCV (image converted to RGB, mask read
    as grayscale). With a transform, they are whatever the transform returns,
    and the mask additionally gets a leading channel axis (1xHxW).
    """

    def __init__(self, dataframe, transform=None):
        """
        Args:
            dataframe: table with the columns 'patch_img_path',
                'patch_mask_path' and 'label'. Its index is reset, so items
                are addressed by position 0..len-1.
            transform: optional callable invoked as transform(image=..., mask=...)
                that returns a mapping with the keys 'image' and 'mask'.
        """
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        """Return the number of patches."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the patch at position `idx`.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        row = self.df.iloc[idx]
        img_path = row["patch_img_path"]
        mask_path = row["patch_mask_path"]

        # cv2.imread returns None instead of raising when a file is missing or
        # unreadable; fail here with the offending path rather than further down.
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Could not read patch image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read patch mask: {mask_path}")

        label = int(row["label"])

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]
            # Add the channel axis -> shape 1xHxW. Transforms that do not end in
            # a tensor conversion hand back a NumPy array, which has no unsqueeze().
            if isinstance(mask, torch.Tensor):
                mask = mask.unsqueeze(0)
            else:
                mask = mask[np.newaxis, ...]

        return image, mask, label

def _make_eval_transform():
    """Build the deterministic transform used for validation and testing."""
    return A.Compose([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])


# --- Training: class-balanced sampling -------------------------------------------
is_train = df_final["split"] == "train"
df_train = df_final[is_train].reset_index(drop=True)

# Weight each sample by the inverse frequency of its label,
# e.g. counts {0: 1000, 1: 600} -> weights {0: 1/1000, 1: 1/600}
label_counts = df_train["label"].value_counts().to_dict()
weights_for_classes = {cls: 1.0 / count for cls, count in label_counts.items()}
sample_weights = df_train["label"].map(weights_for_classes).tolist()

# Draw as many samples per epoch as there are training patches, with replacement
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),
    replacement=True,
)

train_dataset = PatchDataset(df_train, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=8, sampler=sampler, num_workers=4)

# --- Validation and test: no sampler, fixed sequential order ----------------------
is_val = df_final["split"] == "val"
df_val = df_final[is_val].reset_index(drop=True)
val_dataset = PatchDataset(df_val, transform=_make_eval_transform())
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)

is_test = df_final["split"] == "test"
df_test = df_final[is_test].reset_index(drop=True)
test_dataset = PatchDataset(df_test, transform=_make_eval_transform())
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=4)

# --------------------------------------------------------------------------------
# 8. Verification: Print summary of splits and class distribution
# --------------------------------------------------------------------------------

def print_split_info(df_split, split_name):
    """Print patch count, label counts and domain counts of one split.

    Args:
        df_split: rows of the split; must contain the columns 'label' and 'domain'.
        split_name: split name, printed in upper case in the header line.
    """
    per_label = df_split["label"].value_counts().to_dict()
    per_domain = df_split["domain"].value_counts().to_dict()
    report = [
        f"--- {split_name.upper()} ---",
        f"Total patches: {len(df_split)}",
        f"Class distribution (0=IO, 1=NIO): {per_label}",
        f"Domain distribution: {per_domain}",
        "",  # trailing blank line separates consecutive reports
    ]
    print("\n".join(report))

# Report every split in the order train -> val -> test
for split_name in ("train", "val", "test"):
    print_split_info(df_final[df_final["split"] == split_name], split_name)

# You now have:
# - train_loader for balanced training
# - val_loader and test_loader for evaluation
# - "patches_dataset.csv" listing all patch files with labels and splits

# --------------------------------------------------------------------------------
# End of Script
# --------------------------------------------------------------------------------

