# Imports & GPU memory growth (optional)

In [None]:
import numpy as np
import cv2
import albumentations as A
import matplotlib.pyplot as plt

from pathlib import Path
from dataclasses import dataclass
from collections import Counter
from typing import Optional

# Optional TF GPU memory growth (only relevant if TF is actually used later).
# Import failures and configuration failures are reported separately so that a
# working TensorFlow install is never mislabelled as "not available".
try:
    import tensorflow as tf
except Exception as e:  # best-effort: a broken install can raise more than ImportError
    print(f"[INFO] TensorFlow not used / not available ({e})")
else:
    try:
        gpus = tf.config.list_physical_devices("GPU")
        for g in gpus:
            # Ask TF to allocate GPU memory on demand instead of grabbing it all up front.
            tf.config.experimental.set_memory_growth(g, True)
        print(f"[INFO] TF GPUs: {len(gpus)} (memory growth enabled)" if gpus else "[INFO] TF GPUs: none")
    except Exception as e:  # keep the notebook running; this step is optional
        print(f"[WARN] TensorFlow is available, but GPU memory growth could not be configured ({e})")


# Config

In [None]:
@dataclass(frozen=True)
class AugConfig:
    """Immutable settings for the augmentation / class-balancing pipeline.

    A single instance is created right after the class definition and read by
    the geometry pipeline, the mask/photometric helpers and the dataset-building
    code further down in this file.
    """

    # Input arrays: images and their labels (integer or one-hot).
    x_train_path: Path = Path("npyBilah/x_train.npy")
    y_train_path: Path = Path("npyBilah/y_train.npy")

    # If y is one-hot and the class order is reversed (last column = class 0), set True.
    reverse_onehot: bool = False

    # Geometry augmentation (white background).
    # shift/scale/rotate limits feed ShiftScaleRotate; perspective_scale feeds Perspective.
    shift_limit: float = 0.06
    scale_limit: float = 0.06
    rotate_limit: int = 5
    perspective_scale: tuple = (0.02, 0.04)

    # Mask extraction: a pixel counts as background when ALL channels are >= this value.
    white_thresh: int = 245

    # Photometric (applied only on foreground mask)
    brightness_limit: float = 0.2  # additive offset drawn from +/- limit * 255
    contrast_limit: float = 0.2  # multiplicative gain drawn from 1 +/- limit
    blur_p: float = 0.2  # probability of applying the Gaussian blur
    blur_k: int = 3  # blur kernel size (made odd and at least 3 before use)
    noise_p: float = 0.2  # probability of adding Gaussian noise
    noise_var: tuple = (10.0, 50.0)  # range the noise variance is drawn from

    # Aug strategy
    aug_per_image: int = 2  # augmented copies generated for every original image
    balance_to_max: bool = True  # additionally top up under-represented classes

    # Size handling
    expected_size: tuple = (512, 512)  # (H,W) expected; will resize if mismatch

    # Random: seed used for the NumPy random generators in this file.
    seed: int = 42

# Seeding only NumPy is not enough for reproducible runs: the photometric step
# in this file draws from `np.random`, while Albumentations (depending on the
# installed version) may draw from Python's built-in `random` module.
import random

CFG = AugConfig()
random.seed(CFG.seed)
np.random.seed(CFG.seed)
print(CFG)


# Label utilities (safe, consistent)

In [None]:
def to_int_labels(y: np.ndarray) -> np.ndarray:
    """Return labels as a flat int32 vector.

    Input with more than one dimension and more than one entry along the last
    axis is treated as one-hot and collapsed with ``argmax`` along axis 1.
    Anything else is cast to int32 and flattened.
    """
    labels = np.asarray(y)
    looks_onehot = labels.ndim > 1 and labels.shape[-1] > 1
    if not looks_onehot:
        as_int = labels.astype(np.int32)
        return as_int.reshape(-1)
    winners = np.argmax(labels, axis=1)
    return winners.astype(np.int32)

def maybe_reverse_labels(y_int: np.ndarray, num_classes: int, reverse: bool) -> np.ndarray:
    """Mirror class ids (0 <-> num_classes - 1) when ``reverse`` is set.

    With ``reverse`` false the input array is returned untouched.
    """
    if reverse:
        highest_id = num_classes - 1
        return highest_id - y_int
    return y_int

def to_onehot(y_int: np.ndarray, num_classes: int, reverse: bool) -> np.ndarray:
    """One-hot encode integer labels as int32 rows.

    When ``reverse`` is set the ids are mirrored first, so the saved columns
    follow the reversed convention of the original one-hot file.
    """
    identity = np.eye(num_classes, dtype=np.int32)
    row_ids = (num_classes - 1) - y_int if reverse else y_int
    return identity[row_ids]


# Image/mask utilities + photometric-on-mask

In [None]:
def ensure_u8_3ch_bgr(img: np.ndarray) -> np.ndarray:
    """Return ``img`` as a uint8, 3-channel (BGR-ordered) array for OpenCV.

    Value handling:
      - float input whose maximum is <= 1 is treated as [0, 1] and scaled by 255;
      - other float input is treated as already being on the 0..255 scale;
      - both float paths round to nearest before clipping to 0..255;
      - non-float input is cast to uint8 as-is.

    Channel handling:
      - 2-D (H, W) and single-channel (H, W, 1) input is replicated to 3 channels;
      - 3-channel input is kept;
      - anything else keeps its first three channels (for 4-channel BGRA this
        drops the alpha channel).

    The input array is never modified.
    """
    arr = img
    if np.issubdtype(arr.dtype, np.floating):
        if arr.max() <= 1.0 + 1e-6:
            # Round instead of truncating: 0.999 * 255 = 254.745 must become 255,
            # otherwise a float -> uint8 -> float round trip darkens the image.
            arr = np.round(arr * 255.0).clip(0, 255).astype(np.uint8)
        else:
            arr = np.round(arr).clip(0, 255).astype(np.uint8)
    else:
        arr = arr.astype(np.uint8)

    if arr.ndim == 2:
        # Grayscale -> 3 identical channels (same values as COLOR_GRAY2BGR).
        arr = np.repeat(arr[:, :, np.newaxis], 3, axis=2)
    elif arr.ndim == 3 and arr.shape[2] == 1:
        # (H, W, 1) grayscale would otherwise stay single-channel and break
        # downstream code that unpacks three channels.
        arr = np.repeat(arr, 3, axis=2)
    elif arr.ndim == 3 and arr.shape[2] == 3:
        pass
    else:
        # Keep the first three channels; contiguous copy so OpenCV accepts it.
        arr = np.ascontiguousarray(arr[:, :, :3])
    return arr


def make_foreground_mask_from_white(img_u8_bgr: np.ndarray, white_thresh: int = 245) -> np.ndarray:
    """
    Build a boolean foreground mask (True = object) for an image on a near-white background.

    A pixel is background only when every one of its three channels is
    >= white_thresh. The raw mask is then cleaned with a 3x3 elliptical
    morphological opening followed by a closing.
    """
    blue, green, red = cv2.split(img_u8_bgr)
    is_background = (blue >= white_thresh) & (green >= white_thresh) & (red >= white_thresh)

    # 0 = background, 255 = foreground, as expected by the morphology calls.
    fg_u8 = np.where(is_background, 0, 255).astype(np.uint8)

    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    for operation in (cv2.MORPH_OPEN, cv2.MORPH_CLOSE):
        fg_u8 = cv2.morphologyEx(fg_u8, operation, kernel)

    return fg_u8 > 0


def photometric_on_mask(
    img_u8_bgr: np.ndarray,
    mask_bool: np.ndarray,
    brightness_limit: float = 0.2,
    contrast_limit: float = 0.2,
    blur_p: float = 0.2,
    blur_k: int = 3,
    noise_p: float = 0.2,
    noise_var: tuple = (10.0, 50.0),
    bc_p: float = 0.5,
) -> np.ndarray:
    """
    Apply brightness/contrast/blur/noise only inside foreground mask.
    Force background to pure white.

    Args:
        img_u8_bgr: uint8 image; it is copied, never modified in place.
        mask_bool: boolean (H, W) mask, True where the object is.
        brightness_limit: additive offset is drawn from +/- limit * 255.
        contrast_limit: multiplicative gain is drawn from 1 +/- limit.
        blur_p: probability of applying the Gaussian blur.
        blur_k: blur kernel size; made odd and at least 3 before use.
        noise_p: probability of adding zero-mean Gaussian noise.
        noise_var: (low, high) range the noise variance is drawn from.
        bc_p: probability of applying the brightness/contrast adjustment
            (previously hard-coded to 0.5, which remains the default).

    Returns:
        A new uint8 array of the same shape; every pixel outside the mask is 255.

    Randomness comes from the global ``np.random`` state, and the draws happen
    in a fixed order (brightness/contrast, blur, noise).
    """
    out = img_u8_bgr.copy()

    if mask_bool.any():
        # Brightness/contrast
        if np.random.rand() < bc_p:
            alpha = 1.0 + np.random.uniform(-contrast_limit, contrast_limit)
            beta = np.random.uniform(-brightness_limit, brightness_limit) * 255.0
            adj = np.clip(out.astype(np.float32) * alpha + beta, 0, 255).astype(np.uint8)
            out[mask_bool] = adj[mask_bool]

        # Blur: computed on the whole image, but only foreground pixels are kept.
        if np.random.rand() < blur_p:
            k = int(blur_k)
            if k % 2 == 0:
                k += 1  # GaussianBlur needs an odd kernel size
            k = max(k, 3)
            blr = cv2.GaussianBlur(out, (k, k), 0)
            out[mask_bool] = blr[mask_bool]

        # Noise
        if np.random.rand() < noise_p:
            var = np.random.uniform(noise_var[0], noise_var[1])
            sigma = max(1e-6, np.sqrt(var))  # guard against a zero/negative variance
            noise = np.random.normal(0, sigma, out.shape).astype(np.float32)
            ns = np.clip(out.astype(np.float32) + noise, 0, 255).astype(np.uint8)
            out[mask_bool] = ns[mask_bool]

    # Whatever happened above, the background is reset to pure white.
    out[~mask_bool] = 255
    return out


def post_to_original_dtype(img_u8_bgr: np.ndarray, ref: np.ndarray) -> np.ndarray:
    """
    Cast an augmented uint8 image to the dtype (and value scale) of ``ref``:
      - ``ref`` is float with max <= 1: divide by 255 (in float32), then cast
      - otherwise: plain cast to ``ref.dtype`` (values stay on the 0..255 scale)
    """
    target_dtype = ref.dtype
    # ``ref.max()`` is only evaluated for float references.
    unit_range_float = np.issubdtype(target_dtype, np.floating) and ref.max() <= 1.0 + 1e-6
    if not unit_range_float:
        return img_u8_bgr.astype(target_dtype)
    scaled = img_u8_bgr.astype(np.float32) / 255.0
    return scaled.astype(target_dtype)


# Albumentations geometry-only (white border) + single-step augment

In [None]:
# Geometry-only steps. Areas exposed by a transform are filled with white (255)
# in the image and with background (0) in the foreground mask.
_GEOM_STEPS = [
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(
        p=0.7,
        shift_limit=CFG.shift_limit,
        scale_limit=CFG.scale_limit,
        rotate_limit=CFG.rotate_limit,
        border_mode=cv2.BORDER_CONSTANT,
        value=255,
        mask_value=0,
    ),
    A.Perspective(
        p=0.2,
        scale=CFG.perspective_scale,
        keep_size=True,
        pad_mode=cv2.BORDER_CONSTANT,
        pad_val=255,
        mask_pad_val=0,
    ),
]

# Image and mask go through the same sampled transform on every call.
augment_geom = A.Compose(_GEOM_STEPS, additional_targets={"mask": "mask"})

def augment_once(img: np.ndarray) -> np.ndarray:
    """
    Produce one augmented uint8 BGR image from ``img``.

    Steps, in order:
      1. normalise the input to uint8 / 3 channels and derive its foreground mask
      2. run the geometry pipeline on image and mask together
      3. repaint everything outside the warped mask white
      4. apply the photometric changes to the foreground only
      5. resize to CFG.expected_size when the shape differs
    """
    source = ensure_u8_3ch_bgr(img)
    fg_before = make_foreground_mask_from_white(source, white_thresh=CFG.white_thresh)
    mask_u8 = fg_before.astype(np.uint8) * 255

    transformed = augment_geom(image=source, mask=mask_u8)
    warped = transformed["image"]
    fg_after = transformed["mask"] > 0

    # Interpolation can leave off-white pixels outside the object; reset them.
    warped[~fg_after] = 255

    photo_settings = {
        name: getattr(CFG, name)
        for name in (
            "brightness_limit",
            "contrast_limit",
            "blur_p",
            "blur_k",
            "noise_p",
            "noise_var",
        )
    }
    result = photometric_on_mask(warped, fg_after, **photo_settings)

    target_h, target_w = CFG.expected_size
    if result.shape[:2] == (target_h, target_w):
        return result
    # cv2.resize takes the size as (width, height).
    return cv2.resize(result, (target_w, target_h), interpolation=cv2.INTER_LINEAR)


# Build augmented + balanced training set

In [None]:
# Load
X_train = np.load(CFG.x_train_path)
y_train = np.load(CFG.y_train_path)

y_int = to_int_labels(y_train)
num_classes = y_train.shape[-1] if (y_train.ndim > 1 and y_train.shape[-1] > 1) else int(y_int.max() + 1)

# Optional reverse (ONLY if you know your one-hot ordering is reversed)
y_int = maybe_reverse_labels(y_int, num_classes=num_classes, reverse=CFG.reverse_onehot)

# Count with plain Python ints so the printed distribution stays readable.
class_counts = Counter(y_int.tolist())
max_count = max(class_counts.values()) if class_counts else 0

print("[INFO] num_classes:", num_classes)
print("[INFO] original distribution:", class_counts)

# Every class is multiplied by (1 + aug_per_image) in step A, so the largest
# class ends up with this many samples. Balancing has to aim for that figure;
# aiming for the original max_count leaves the final set imbalanced.
per_image_factor = 1 + CFG.aug_per_image
target_count = max_count * per_image_factor

# post_to_original_dtype only needs the dtype and the global maximum of the
# training data. Passing X_train itself would rescan the whole dataset for
# every generated image, so hand it a 0-d stand-in computed once.
dtype_ref = np.asarray(X_train.max(), dtype=X_train.dtype)

aug_images = []
aug_labels = []

for cls in range(num_classes):
    idx = np.where(y_int == cls)[0]
    n = len(idx)
    print(f"\n-- Class {cls}: {n} original sample(s)")

    if n == 0:
        continue

    # A) fixed aug_per_image: each original index repeated back-to-back
    picks = np.repeat(idx, CFG.aug_per_image)

    # B) balance up to the size of the largest class (optional):
    #    cycle through this class's originals until the gap is filled
    if CFG.balance_to_max:
        needed = max(0, target_count - n * per_image_factor)
        picks = np.concatenate([picks, np.resize(idx, needed)])

    for i in picks:
        out_u8 = augment_once(X_train[i])
        aug_images.append(post_to_original_dtype(out_u8, dtype_ref))
        aug_labels.append(cls)

    print(f"  -> added aug samples (so far): {len(aug_labels)}")

if aug_images:
    aug_images = np.asarray(aug_images, dtype=X_train.dtype)
else:
    # Nothing generated (e.g. aug_per_image == 0 without balancing): keep a
    # correctly shaped empty array so the concatenation below still works.
    aug_images = np.empty((0,) + X_train.shape[1:], dtype=X_train.dtype)
aug_labels = np.asarray(aug_labels, dtype=np.int32)

if aug_images.shape[1:] != X_train.shape[1:]:
    # augment_once always returns 3-channel images of CFG.expected_size.
    raise ValueError(
        f"Augmented image shape {aug_images.shape[1:]} does not match the original "
        f"image shape {X_train.shape[1:]}; check CFG.expected_size and the channel count."
    )

X_train_bal = np.concatenate([X_train, aug_images], axis=0)
y_train_bal_int = np.concatenate([y_int, aug_labels], axis=0)

print("\n[INFO] final distribution:", Counter(y_train_bal_int.tolist()))
print("[INFO] X_train_bal:", X_train_bal.shape, X_train_bal.dtype)


# Convert back to one-hot (matching original convention) + save (optional)

In [None]:
# Encode the balanced labels in the same one-hot convention as the input file.
y_train_bal_onehot = to_onehot(
    y_train_bal_int,
    num_classes=num_classes,
    reverse=CFG.reverse_onehot,
)
print("[INFO] y_train_bal_onehot:", y_train_bal_onehot.shape, y_train_bal_onehot.dtype)

# Optional save: both arrays go into one output folder, created on demand.
OUT_DIR = Path("npyBilah/aug")
OUT_DIR.mkdir(parents=True, exist_ok=True)

_x_out = OUT_DIR / "X_train_aug.npy"
_y_out = OUT_DIR / "y_train_aug.npy"

np.save(_x_out, X_train_bal)
np.save(_y_out, y_train_bal_onehot)

for _saved in (_x_out, _y_out):
    print("[OK] Saved:", _saved)


# Preview: 1 image per class → 8 augmentations

In [None]:
def bgr_to_rgb(img_bgr: np.ndarray) -> np.ndarray:
    """Convert an OpenCV BGR image to RGB channel order (e.g. for matplotlib)."""
    rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    return rgb

def preview_augments_for_class(X: np.ndarray, y_int: np.ndarray, class_id: int, repeats: int = 8):
    """Show one random sample of ``class_id`` next to ``repeats`` augmentations of it.

    Args:
        X: image array indexed by sample.
        y_int: integer class label per sample.
        class_id: class to preview; a notice is printed and nothing is drawn
            when the class has no samples.
        repeats: number of augmented variants to generate and display.

    The grid is sized from ``repeats`` (near-square, original in the top-left
    cell), so any non-negative value works; ``repeats=8`` gives the usual 3x3.
    """
    idx = np.where(y_int == class_id)[0]
    if idx.size == 0:
        print(f"[SKIP] No samples for class {class_id}")
        return

    j = np.random.choice(idx, 1, replace=False)[0]
    orig_u8 = ensure_u8_3ch_bgr(X[j])
    augs = [augment_once(orig_u8) for _ in range(repeats)]

    # One panel for the original plus one per augmentation.
    panels = [("Original", orig_u8)]
    panels.extend((f"Aug {k}", aug) for k, aug in enumerate(augs, start=1))

    cols = int(np.ceil(np.sqrt(len(panels))))
    rows = int(np.ceil(len(panels) / cols))
    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)

    flat_axes = axes.ravel()  # row-major: left to right, top to bottom
    for ax in flat_axes:
        ax.axis("off")  # also hides cells that stay empty
    for ax, (title, img_bgr) in zip(flat_axes, panels):
        ax.imshow(bgr_to_rgb(img_bgr))
        ax.set_title(title)

    plt.suptitle(f"Class {class_id}: 1 sample → {repeats} augmentations", y=0.98)
    plt.tight_layout()
    plt.show()

# One preview figure for every class that actually occurs in the labels.
classes = np.unique(y_int).tolist()
classes.sort()
for cls in classes:
    print(f"\n=== Class {cls} ===")
    preview_augments_for_class(X_train, y_int, class_id=cls, repeats=8)


# Load the saved augmented .npy files + per-class grid preview

In [None]:
from collections import Counter
import math

# Read back the arrays written by the save step.
X_aug, y_aug = (
    np.load(npy_path)
    for npy_path in ("npyBilah/aug/X_train_aug.npy", "npyBilah/aug/y_train_aug.npy")
)

# Collapse the one-hot labels to ints; if the file uses the reversed column
# convention, map them back to the natural class order for analysis.
y_aug_int = maybe_reverse_labels(
    to_int_labels(y_aug),
    num_classes=y_aug.shape[-1],
    reverse=CFG.reverse_onehot,
)

print("[INFO] AUG distribution:", Counter(y_aug_int))

def ensure_rgb_u8_for_show(img: np.ndarray) -> np.ndarray:
    """Return ``img`` as a uint8 RGB array suitable for ``imshow``.

    Value handling:
      - float input whose maximum is <= 1 is treated as [0, 1] and scaled by 255;
      - other float input is treated as already being on the 0..255 scale;
      - both float paths round to nearest before clipping to 0..255;
      - non-float input is cast to uint8 as-is.

    Channel handling (colour input is assumed to be BGR-ordered, as produced
    by the augmentation pipeline in this file):
      - 2-D (H, W) and single-channel (H, W, 1) input is replicated to 3 channels;
      - 3-channel input has its channel order reversed (BGR -> RGB);
      - 4-channel input drops the last channel and reverses the rest (BGRA -> RGB);
      - any other shape is returned without channel changes.

    The input array is never modified.
    """
    arr = img
    if np.issubdtype(arr.dtype, np.floating):
        if arr.max() <= 1.0 + 1e-6:
            # Round instead of truncating so 0.999 maps to 255, not 254.
            arr = np.round(arr * 255.0).clip(0, 255).astype(np.uint8)
        else:
            arr = np.round(arr).clip(0, 255).astype(np.uint8)
    else:
        arr = arr.astype(np.uint8)

    # assume BGR from pipeline
    if arr.ndim == 2:
        arr = np.repeat(arr[:, :, np.newaxis], 3, axis=2)
    elif arr.ndim == 3 and arr.shape[2] == 1:
        # (H, W, 1) is expanded explicitly so the result is always 3-channel.
        arr = np.repeat(arr, 3, axis=2)
    elif arr.ndim == 3 and arr.shape[2] == 3:
        arr = np.ascontiguousarray(arr[:, :, ::-1])
    elif arr.ndim == 3 and arr.shape[2] == 4:
        # Channels 2, 1, 0 -> R, G, B; the fourth (alpha) channel is dropped.
        arr = np.ascontiguousarray(arr[:, :, 2::-1])
    return arr

# Dedicated generator so the sampled preview indices are repeatable.
rng = np.random.default_rng(CFG.seed)

def show_aug_grid_for_class(cls_id: int, k: int = 8):
    """Plot up to ``k`` samples of class ``cls_id`` from the loaded arrays in a 2-row grid.

    Samples are taken from ``X_aug`` / ``y_aug_int``. The saved file holds the
    original images followed by the augmented ones, so a grid may show either
    kind. When the class has more than ``k`` samples, ``k`` of them are drawn
    without replacement using the module-level ``rng``.
    """
    idx = np.where(y_aug_int == cls_id)[0]
    if idx.size == 0:
        print(f"[SKIP] Class {cls_id}: no samples.")
        return

    idxs = idx if idx.size <= k else rng.choice(idx, size=k, replace=False)
    idxs = np.asarray(idxs)

    rows = 2
    # At least one column, so an empty selection (k == 0) cannot break subplots().
    cols = max(1, int(math.ceil(len(idxs) / rows)))
    # squeeze=False keeps `axes` 2-D even when there is a single column.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 2.2, rows * 2.2), squeeze=False)

    for i in range(rows * cols):
        r, c = divmod(i, cols)
        ax = axes[r, c]
        ax.axis("off")
        if i < len(idxs):
            ax.imshow(ensure_rgb_u8_for_show(X_aug[idxs[i]]))
            ax.set_title(f"#{i+1}", fontsize=9)

    # The array contains originals as well as augmentations, so do not claim "AUG only".
    plt.suptitle(f"Original + AUG — Class {cls_id} | shown {len(idxs)}", y=1.02)
    plt.tight_layout()
    plt.show()

# Render one grid per class present in the loaded label set, in ascending order.
_present_classes = np.unique(y_aug_int).tolist()
for cls in sorted(_present_classes):
    show_aug_grid_for_class(cls_id=cls, k=8)
