# README:

Train a neural network to automatically detect and segment damage in artwork images. Use an encoder-decoder (U-Net) structure for image segmentation, with attention gates that help the model focus on actual damage areas rather than background noise. Use Focal Tversky loss because it penalizes missed damage (false negatives) more heavily than false alarms and helps the model detect smaller damage areas.

In [None]:
# Imports
import os
import time
import random
import gc
import psutil
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor
import cv2
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds
from google.colab import drive


In [None]:
# Mount Google Drive (Colab only) so files under /content/drive, such as the dataset directory, are readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Configuration
TARGET_SIZE = 256  # Square input resolution; reduced to 256 due to memory. Keep it a multiple of 16 (the U-Net pools 4 times).
MAX_WORKERS = 8  # Threads used for parallel image loading
BATCH_SIZE = 16  # NOTE: the training cell redefines BATCH_SIZE before calling model.fit
MAX_SAMPLES = 6000  # Limit to first x images


# Setup paths
print("Setting up data paths")
dataset_dir = "/content/drive/MyDrive/artwork_data_for_masking2"
damaged_dir = os.path.join(dataset_dir, "damaged")
mask_dir = os.path.join(dataset_dir, "masks")

# Fail fast with a clear message instead of reporting "Dataset found" for a path
# that does not exist (e.g. when Google Drive has not been mounted).
for _required_dir in (dataset_dir, damaged_dir, mask_dir):
    if not os.path.isdir(_required_dir):
        raise FileNotFoundError(f"Required directory not found: {_required_dir}")

print(f"Dataset found: {dataset_dir}")
print(f"Damaged dir: {damaged_dir}")
print(f"Mask dir: {mask_dir}")

# Find files
def find_files(root_dir, suffix):
    """Return sorted paths, relative to ``root_dir``, of all files whose name ends with ``suffix``.

    The tree is walked recursively; a ``root_dir`` that does not exist yields ``[]``.
    """
    file_paths = []
    for root, _, files in os.walk(root_dir):
        for file in files:
            if file.endswith(suffix):
                file_paths.append(os.path.relpath(os.path.join(root, file), root_dir))
    return sorted(file_paths)


def _strip_suffix(path, suffix):
    """Return ``path`` with a trailing ``suffix`` removed (only the end of the string is touched)."""
    return path[:-len(suffix)] if suffix and path.endswith(suffix) else path


def find_paired_files(damaged_dir, mask_dir):
    """Pair each ``*_damaged.jpg`` under ``damaged_dir`` with its ``*_mask.jpg`` under ``mask_dir``.

    Files are matched on their relative path with the suffix removed, so
    ``baroque/x_damaged.jpg`` pairs with ``baroque/x_mask.jpg``.

    Returns:
        A list of dicts, in sorted order of the damaged-image path, with keys
        ``'damaged'`` and ``'mask'`` (full paths), ``'category'`` (the relative
        sub-directory, or ``"unknown"`` for files at the top level) and
        ``'name'`` (file base name without the suffix). Damaged images that
        have no mask are skipped with a warning. Returns ``[]`` when either
        directory contains no matching files.
    """
    print("Discovering file pairs")

    # Find all damaged images and masks recursively (subfolders group images by art movement)
    damaged_files = find_files(damaged_dir, '_damaged.jpg')
    mask_files = find_files(mask_dir, '_mask.jpg')

    print(f"Found {len(damaged_files)} damaged images")
    print(f"Found {len(mask_files)} mask images")
    if not damaged_files or not mask_files:
        print("No files found. Check paths:")
        print(f"Damaged: {damaged_dir}")
        print(f"Masks: {mask_dir}")
        return []

    # Index masks by base path once: O(n + m) instead of scanning every mask for
    # every damaged image. Since every mask path is exactly base + '_mask.jpg',
    # a base maps to at most one mask.
    mask_by_base = {_strip_suffix(m, '_mask.jpg'): m for m in mask_files}

    pairs = []
    for d in damaged_files:
        base = _strip_suffix(d, '_damaged.jpg')
        mask_rel = mask_by_base.get(base)
        if mask_rel is None:
            print(f"Warning: No mask found for {d}")
            continue

        pairs.append({
            'damaged': os.path.join(damaged_dir, d),
            'mask': os.path.join(mask_dir, mask_rel),
            'category': os.path.dirname(d) or "unknown",
            'name': os.path.basename(base),
        })

    print(f"Found {len(pairs)} valid pairs")
    return pairs

# Discover all pairs
file_pairs = find_paired_files(damaged_dir, mask_dir)
if not file_pairs:
    # Raise rather than call exit(): inside a Jupyter/Colab kernel exit() ends the
    # kernel session instead of just stopping this cell.
    raise RuntimeError("No valid pairs found! Check your dataset structure.")

# Limit number of files for memory purposes
if len(file_pairs) > MAX_SAMPLES:
    print(f"Limiting dataset to first {MAX_SAMPLES} samples (from {len(file_pairs)})")
    # NOTE(review): file_pairs is in sorted path order (category folder first), so taking
    # the first N can under-represent or drop categories whose names sort late.
    file_pairs = file_pairs[:MAX_SAMPLES]
else:
    print(f"Using all {len(file_pairs)} samples (less than {MAX_SAMPLES})")

# Memory estimate, used to choose the input resolution before any image is loaded
print("\nMemory optimization")
available_memory_gb = psutil.virtual_memory().available / (1024**3)
total_samples = len(file_pairs)

# Calculate memory requirements
bytes_per_sample = TARGET_SIZE * TARGET_SIZE * 4 * 4  # 4 channels (RGB + mask), 4 bytes per float32
estimated_memory_gb = (total_samples * bytes_per_sample) / (1024**3)

print(f"Available memory: {available_memory_gb:.1f} GB")
print(f"Total samples: {total_samples}")
print(f"Estimated memory needed: {estimated_memory_gb:.1f} GB")

# Auto-adjust target size if needed (memory grows with the square of the side length)
if estimated_memory_gb > available_memory_gb * 0.8:
    new_target_size = int(TARGET_SIZE * np.sqrt(available_memory_gb * 0.8 / estimated_memory_gb))
    # Round down to a multiple of 16: the U-Net halves the resolution four times, and
    # any other size makes the decoder's skip connections mismatch in shape.
    new_target_size -= new_target_size % 16
    TARGET_SIZE = max(128, new_target_size)  # Do not go lower than 128 for resolution purposes
    print(f"Reduced target size to {TARGET_SIZE} to fit in memory")

# Group by category for balanced splitting (try to get a diverse set of damaged
# paintings from each category to train the model)
categories = {}
for pair in file_pairs:
    categories.setdefault(pair['category'], []).append(pair)

print(f"Categories found: {list(categories.keys())}")

# Stratified splitting - same 70/15/15 ratio inside every category.
# A dedicated, seeded RNG keeps the split reproducible; otherwise every run would
# evaluate on a different test set.
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)

train_pairs = []
val_pairs = []
test_pairs = []

for category, pairs in categories.items():
    n = len(pairs)
    print(f"  {category}: {n} samples")

    # Shuffle within category
    split_rng.shuffle(pairs)

    # Split: 70% train, 15% val, 15% test
    train_end = int(0.7 * n)
    val_end = int(0.85 * n)

    train_pairs.extend(pairs[:train_end])
    val_pairs.extend(pairs[train_end:val_end])
    test_pairs.extend(pairs[val_end:])

print("\nSplit results (maintaining 70/15/15 ratio):")
print(f"Training samples: {len(train_pairs)} ({len(train_pairs)/total_samples*100:.1f}%)")
print(f"Validation samples: {len(val_pairs)} ({len(val_pairs)/total_samples*100:.1f}%)")
print(f"Test samples: {len(test_pairs)} ({len(test_pairs)/total_samples*100:.1f}%)")

# Preprocessing functions
def preprocess_image_array(img, target_size=TARGET_SIZE, grayscale=False):
    """Letterbox an image to a square, resize it, and scale pixel values by 1/255.

    Args:
        img: Image array as returned by ``cv2.imread``, or ``None``.
        target_size: Side length of the square output.
        grayscale: When true, a 3-dimensional image is converted with
            ``cv2.COLOR_BGR2GRAY`` before padding.

    Returns:
        ``float32`` array of side ``target_size`` (channels kept as after the
        optional grayscale conversion), or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        return None

    if grayscale and img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    height, width = img.shape[:2]
    side = max(height, width)

    # Pad only the shorter side, split evenly; an odd leftover pixel goes bottom/right.
    extra_w = side - width
    extra_h = side - height
    left, top = extra_w // 2, extra_h // 2
    right, bottom = extra_w - left, extra_h - top

    fill = [0, 0, 0] if img.ndim == 3 else 0
    squared = cv2.copyMakeBorder(img, top, bottom, left, right,
                                 cv2.BORDER_CONSTANT, value=fill)

    resized = cv2.resize(squared, (target_size, target_size),
                         interpolation=cv2.INTER_LANCZOS4)
    return resized.astype(np.float32) / 255.0

def load_single_pair_optimized(pair_info):
    """Load and preprocess one damaged-image / mask pair.

    Args:
        pair_info: Dict with at least ``'damaged'`` and ``'mask'`` file paths,
            as produced by ``find_paired_files``.

    Returns:
        ``(image, mask, error)``. On success ``image`` is the preprocessed
        colour image, ``mask`` is a ``(TARGET_SIZE, TARGET_SIZE, 1)`` float32
        array containing only 0.0 and 1.0, and ``error`` is ``None``. On failure
        both arrays are ``None`` and ``error`` is a message; exceptions are
        caught and reported this way so one bad file does not abort a batch.
    """
    try:
        # Load damaged image
        damaged_img = cv2.imread(pair_info['damaged'], cv2.IMREAD_COLOR)
        if damaged_img is None:
            return None, None, f"Failed to load damaged: {pair_info['damaged']}"

        # Load mask
        mask_img = cv2.imread(pair_info['mask'], cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            return None, None, f"Failed to load mask: {pair_info['mask']}"

        # Preprocess
        damaged_processed = preprocess_image_array(damaged_img, TARGET_SIZE)
        mask_processed = preprocess_image_array(mask_img, TARGET_SIZE, grayscale=True)

        if damaged_processed is None or mask_processed is None:
            return None, None, "Preprocessing failed"

        # Masks are JPEG files and are resized with Lanczos interpolation, so their
        # edges contain intermediate values. Re-binarize at the same 0.5 threshold
        # the loss uses, so every metric sees the same 0/1 labels the model trains on.
        mask_processed = (mask_processed > 0.5).astype(np.float32)

        # Add channel dimension to mask
        mask_processed = mask_processed[..., np.newaxis]

        return damaged_processed, mask_processed, None

    except Exception as e:
        return None, None, f"Error: {str(e)}"


# Optimized revision of batch loading function
def batch_load_optimized(pairs_list, desc="Loading"):
    """Load and preprocess a list of image/mask pairs in parallel.

    Pairs are submitted in chunks of up to 100 to one shared thread pool of
    ``MAX_WORKERS`` threads. Successful samples are written straight into
    pre-allocated ``float32`` arrays. This avoids the second full-size copy
    that collecting a Python list and then calling ``np.array`` would need.

    Args:
        pairs_list: Pair dicts accepted by ``load_single_pair_optimized``.
        desc: Label used in progress output.

    Returns:
        ``(images, masks)`` stacked along a new leading axis, in
        ``pairs_list`` order with failed pairs skipped. Both are empty arrays
        when ``pairs_list`` is empty or nothing could be loaded.
    """
    if not pairs_list:
        return np.array([]), np.array([])

    print(f"{desc} {len(pairs_list)} samples...")

    total = len(pairs_list)
    images = None  # allocated once the first sample reveals the array shapes
    masks = None
    loaded = 0
    errors = []

    # Submit work in chunks so only a bounded number of decoded samples are
    # pending at once; one pool is reused for all chunks.
    chunk_size = min(100, total)

    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        for start in tqdm(range(0, total, chunk_size), desc=desc):
            chunk = pairs_list[start:start + chunk_size]
            futures = [executor.submit(load_single_pair_optimized, pair) for pair in chunk]

            for future in futures:
                try:
                    img, mask, error = future.result()
                except Exception as e:
                    errors.append(f"Batch processing error: {str(e)}")
                    continue
                if error:
                    errors.append(error)
                    continue
                if images is None:
                    images = np.empty((total,) + img.shape, dtype=np.float32)
                    masks = np.empty((total,) + mask.shape, dtype=np.float32)
                images[loaded] = img
                masks[loaded] = mask
                loaded += 1

    if errors:
        print(f"{len(errors)} errors encountered")
        # Show a few examples so a large failure count still gives a diagnostic
        # without flooding the output.
        for error in errors[:5]:
            print(f"   {error}")
        if len(errors) > 5:
            print(f"   ... and {len(errors) - 5} more")

    print(f"Successfully loaded {loaded} pairs")

    if loaded == 0:
        return np.array([]), np.array([])

    # Slicing returns a view; rows reserved for failed pairs stay allocated but unused.
    return images[:loaded], masks[:loaded]

start_time = time.time()

# Load each split, collecting garbage between loads to release temporaries
train_images, train_masks = batch_load_optimized(train_pairs, "Training")
gc.collect()

val_images, val_masks = batch_load_optimized(val_pairs, "Validation")
gc.collect()

test_images, test_masks = batch_load_optimized(test_pairs, "Test")
gc.collect()

total_time = time.time() - start_time

# Final dataset info
total_loaded = len(train_images) + len(val_images) + len(test_images)
current_memory = psutil.virtual_memory()
print(f"\nLoaded {total_loaded} samples in {total_time:.1f}s")
print(f"Memory in use: {current_memory.percent:.1f}% "
      f"({current_memory.available / (1024**3):.1f} GB available)")

# Data quality validation for every split: one isfinite pass per array catches
# both NaN and +/-inf values.
for split_name, split_images, split_masks in (
    ("training", train_images, train_masks),
    ("validation", val_images, val_masks),
    ("test", test_images, test_masks),
):
    if len(split_images) == 0:
        continue
    if not np.isfinite(split_images).all():
        raise ValueError(f"NaN or infinite values in {split_name} images")
    if not np.isfinite(split_masks).all():
        raise ValueError(f"NaN or infinite values in {split_name} masks")

In [None]:
import tensorflow as tf
from tensorflow.keras import layers

# Define attention gate and encoder decoder blocks

def attention_gate(x, g, inter_channels):
    """Additive attention gate that re-weights skip features ``x`` using a gating signal ``g``.

    ``x`` and ``g`` are each projected to ``inter_channels`` with a 1x1
    convolution. The projected ``g`` is upsampled 2x (bilinear) to ``x``'s
    resolution, the two are summed and passed through ReLU, and a 1x1
    convolution + sigmoid produces a single-channel attention map that
    multiplies ``x``.

    Args:
        x: Skip-connection tensor from the encoder.
        g: Gating tensor from the coarser decoder level; assumed to have half
            of ``x``'s spatial size.
        inter_channels: Channels of the intermediate projections.

    Returns:
        ``x`` scaled element-wise by the attention map.
    """
    # Linear transformations
    theta_x = layers.Conv2D(inter_channels, 1, padding='same')(x)
    phi_g = layers.Conv2D(inter_channels, 1, padding='same')(g)

    # Upsample gating signal to match x dimensions
    phi_g = layers.UpSampling2D(size=(2, 2), interpolation='bilinear')(phi_g)

    # Resize only if the upsampled gate still differs from x (e.g. odd sizes). An
    # unconditional Resizing layer is a redundant op for matching shapes and cannot
    # be built when the spatial dimensions are unknown (None).
    if tuple(phi_g.shape[1:3]) != tuple(theta_x.shape[1:3]):
        phi_g = layers.Resizing(theta_x.shape[1], theta_x.shape[2])(phi_g)

    # Add and apply activation
    add_xg = layers.Add()([theta_x, phi_g])
    add_xg = layers.Activation('relu')(add_xg)

    # Generate attention coefficients in (0, 1)
    psi = layers.Conv2D(1, 1, padding='same')(add_xg)
    psi = layers.Activation('sigmoid')(psi)

    # Apply attention weights
    x_att = layers.Multiply()([x, psi])

    return x_att

def encoder_block(inputs, num_filters):
    """Two 3x3 Conv-BatchNorm-ReLU stages followed by 2x2 max pooling.

    Returns:
        ``(pooled, skip)``: the max-pooled tensor passed to the next encoder
        level, and the pre-pooling activations kept for the decoder skip connection.
    """
    features = inputs
    for _ in range(2):
        features = layers.Conv2D(num_filters, 3, padding='same')(features)
        features = layers.BatchNormalization()(features)
        features = layers.Activation('relu')(features)

    pooled = layers.MaxPool2D(pool_size=(2, 2))(features)
    return pooled, features

def decoder_block(inputs, skip_features, num_filters):
    """Upsample ``inputs`` 2x, fuse with the attention-gated skip connection, and refine.

    Args:
        inputs: Tensor from the coarser decoder level (or bottleneck).
        skip_features: Encoder activations at the target resolution.
        num_filters: Channels for the transposed and regular convolutions.

    Returns:
        Output of two 3x3 Conv-BatchNorm-ReLU stages applied to the concatenation.
    """
    upsampled = layers.Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)

    # The coarser tensor (before upsampling) serves as the gating signal.
    gated_skip = attention_gate(skip_features, inputs, num_filters // 2)

    merged = layers.Concatenate()([upsampled, gated_skip])
    for _ in range(2):
        merged = layers.Conv2D(num_filters, 3, padding='same')(merged)
        merged = layers.BatchNormalization()(merged)
        merged = layers.Activation('relu')(merged)
    return merged

def build_attention_unet(input_shape=(256, 256, 3)):
    """Build a 4-level U-Net with attention-gated skip connections for binary segmentation.

    Encoder levels use 64/128/256/512 filters with a 1024-filter bottleneck;
    the decoder mirrors them, and a 1x1 convolution with sigmoid outputs a
    single-channel probability map at the input resolution.

    Args:
        input_shape: ``(height, width, channels)``. Height and width must be
            multiples of 16 because the encoder halves them four times.

    Returns:
        An uncompiled ``tf.keras.Model`` named ``'AttentionUNet'``.

    Raises:
        ValueError: if a known spatial dimension is not a multiple of 16.
    """
    # Fail early with a clear message; otherwise the size mismatch only surfaces
    # as an opaque shape error inside a decoder Concatenate layer.
    for dim in input_shape[:2]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f"Spatial input dimensions must be multiples of 16, got {tuple(input_shape)}"
            )

    inputs = layers.Input(input_shape)

    # Encoder path (comments give the pooled output size relative to input H x W)
    e1, skip1 = encoder_block(inputs, 64)     # H/2
    e2, skip2 = encoder_block(e1, 128)        # H/4
    e3, skip3 = encoder_block(e2, 256)        # H/8
    e4, skip4 = encoder_block(e3, 512)        # H/16

    # Bottleneck
    b = layers.Conv2D(1024, 3, padding='same')(e4)
    b = layers.BatchNormalization()(b)
    b = layers.Activation('relu')(b)
    b = layers.Conv2D(1024, 3, padding='same')(b)
    b = layers.BatchNormalization()(b)
    b = layers.Activation('relu')(b)

    # Decoder path with attention gates
    d1 = decoder_block(b, skip4, 512)     # H/8
    d2 = decoder_block(d1, skip3, 256)    # H/4
    d3 = decoder_block(d2, skip2, 128)    # H/2
    d4 = decoder_block(d3, skip1, 64)     # H

    # Final output layer: per-pixel damage probability
    outputs = layers.Conv2D(1, 1, activation='sigmoid', padding='same')(d4)

    return tf.keras.Model(inputs, outputs, name='AttentionUNet')

def focal_tversky_loss(y_true, y_pred, alpha=0.7, gamma=0.75, smooth=1e-6):
    """Focal Tversky loss for binary segmentation.

    Computes the Tversky index over all pixels of the batch,
        TI = (TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth),
    and returns ``(1 - TI) ** gamma``. With ``alpha > 0.5``, false negatives
    (missed damage) cost more than false positives.

    Args:
        y_true: Target masks; binarized with a 0.5 threshold.
        y_pred: Predicted probabilities (the model's sigmoid outputs).
        alpha: Weight on false negatives; false positives get ``1 - alpha``.
        gamma: Focal exponent applied to ``1 - TI``.
        smooth: Smoothing constant that keeps the ratio defined for empty masks.

    Returns:
        A scalar loss tensor for the whole batch.
    """
    y_true = tf.cast(y_true > 0.5, tf.float32)
    # Match dtypes, e.g. if the model ever emits float16 under mixed precision.
    y_pred = tf.cast(y_pred, tf.float32)

    # Calculate Tversky components
    tp = tf.reduce_sum(y_true * y_pred)
    fn = tf.reduce_sum(y_true * (1 - y_pred))
    fp = tf.reduce_sum((1 - y_true) * y_pred)

    # Tversky index
    tversky = (tp + smooth) / (tp + alpha * fn + (1 - alpha) * fp + smooth)

    # d/dx x**gamma is infinite at x == 0 when gamma < 1, so a perfect (saturated)
    # prediction would yield NaN gradients; keep the base strictly positive.
    return tf.pow(tf.maximum(1 - tversky, 1e-7), gamma)

# Build and compile the model
# Use the (possibly auto-reduced) TARGET_SIZE so the model input matches the loaded arrays.
model = build_attention_unet(input_shape=(TARGET_SIZE, TARGET_SIZE, 3))

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=focal_tversky_loss,
    metrics=[
        'accuracy',
        # MeanIoU expects integer class ids, so sigmoid probabilities are truncated
        # to class 0 almost everywhere. BinaryIoU thresholds them at 0.5 and reports
        # the IoU of the damage class.
        tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5),
        tf.keras.metrics.Precision(),
        tf.keras.metrics.Recall()
    ]
)

# Display model summary
model.summary()

In [None]:
# Callbacks for training
# Stop training if the validation loss doesn't improve for 5 epochs (restoring the best
# weights), save the best model based on validation loss, and reduce the learning rate
# when the validation loss plateaus.
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        patience=5,
        monitor='val_loss',
        restore_best_weights=True
    ),
    # Native `.keras` format: recent Keras versions reject or deprecate `.h5` paths
    # for full-model checkpoints, and `.keras` matches the format used for the upload.
    tf.keras.callbacks.ModelCheckpoint(
        'best_model.keras',
        save_best_only=True,
        monitor='val_loss'
    ),
    tf.keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',  # explicit for clarity; this is also the default
        factor=0.1,
        patience=3,
        verbose=1
    )
]

In [None]:
# Training configuration
EPOCHS = 40
BATCH_SIZE = 8  # Overrides the configuration-cell value; adjust to your GPU memory

# Verify data shapes
print("\nFinal data shapes verification:")
print(f"Train images: {train_images.shape}, masks: {train_masks.shape}")
print(f"Val images: {val_images.shape}, masks: {val_masks.shape}")

# Fail with a clear message instead of an opaque error from inside model.fit.
if len(train_images) == 0 or len(val_images) == 0:
    raise ValueError("Training and validation sets must both be non-empty")
expected_input = tuple(model.inputs[0].shape[1:])
if tuple(train_images.shape[1:]) != expected_input:
    raise ValueError(
        f"Image shape {tuple(train_images.shape[1:])} does not match model input "
        f"{expected_input}; rebuild the model with the current TARGET_SIZE"
    )

# Training loop
history = model.fit(
    x=train_images,
    y=train_masks,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(val_images, val_masks),
    callbacks=callbacks,
    verbose=1
)


In [None]:
# Next couple of cells are just to deploy the model to HuggingFace
# Install the Hugging Face Hub client (IPython shell command; -q suppresses most pip output).

!pip install huggingface_hub -q

In [None]:
# Interactive Hugging Face login; the repo-creation and upload cells below need an authenticated session.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
from huggingface_hub import create_repo

# Ensure the target model repository exists on the Hub; exist_ok=True means no error if it already does.
repo_id = "maximiliannl/artwork-damage-detector"
create_repo(repo_id=repo_id, exist_ok=True, repo_type="model")

RepoUrl('https://huggingface.co/maximiliannl/artwork-damage-detector', endpoint='https://huggingface.co', repo_type='model', repo_id='maximiliannl/artwork-damage-detector')

In [None]:
# Save the trained model in the native Keras (.keras) format; the next cell uploads this file.
# NOTE(review): the model is compiled with the custom `focal_tversky_loss`, so loading it
# elsewhere likely needs custom_objects={'focal_tversky_loss': ...} or compile=False — confirm.
model.save("artwork_damage_detector.keras")

In [None]:
from huggingface_hub import HfApi

repo_id = "maximiliannl/artwork-damage-detector"

# Push the locally saved archive to the Hub model repo, stored there as model.keras.
api = HfApi()
api.upload_file(
    repo_id=repo_id,
    repo_type="model",
    path_or_fileobj="artwork_damage_detector.keras",
    path_in_repo="model.keras",
)