# RAPUNet Training for Image Segmentation

This notebook implements training and evaluation of RAPUNet for image segmentation.

In [None]:
import gc
import glob
import os
from datetime import datetime

import albumentations as albu
import cv2
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from sklearn.metrics import jaccard_score, precision_score, recall_score, accuracy_score, f1_score

# Import custom modules
from ModelArchitecture.DiceLoss import dice_metric_loss
from ModelArchitecture import RAPUNet
from CustomLayers import ImageLoader2D

In [None]:
# Configuration
IMG_SIZE = 352  # side length in pixels; used for both the height and width of model inputs
BATCH_SIZE = 8  # batch size used by both model.fit and model.predict
EPOCHS = 100  # maximum number of epochs passed to model.fit (early stopping may end sooner)
FILTERS = 17  # forwarded to RAPUNet.create_model as starting_filters

# Data paths
# Relative directories, resolved against the current working directory.
TRAIN_PATH = "train/"
VAL_PATH = "val/"
TEST_PATH = "test/"

# Polynomial decay schedule: goes from 1e-4 down to 1e-6 over 1000 steps.
starter_learning_rate = 1e-4
end_learning_rate = 1e-6
decay_steps = 1000
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=starter_learning_rate,
    decay_steps=decay_steps,
    end_learning_rate=end_learning_rate,
    power=0.2,
)

# AdamW optimizer.
# NOTE(review): the schedule above is named like a learning-rate schedule but
# is passed as `weight_decay`, while the learning rate itself stays constant
# at 1e-4. AdamW accepts a schedule for weight decay, so this runs; confirm
# that decaying the weight decay (rather than the learning rate) is intended.
optimizer = tfa.optimizers.AdamW(
    weight_decay=learning_rate_fn,
    learning_rate=1e-4,
)

In [None]:
# Read the training and validation splits into memory.
# Both splits share every loader argument except the folder.
# NOTE(review): the meaning of the third argument (-1) is defined by
# ImageLoader2D.load_data, which is not visible here - presumably
# "load every image"; confirm against the loader.
_loader_args = (IMG_SIZE, IMG_SIZE, -1, 'jpg')

print("Loading training data...")
x_train, y_train = ImageLoader2D.load_data(*_loader_args, TRAIN_PATH)

print("Loading validation data...")
x_valid, y_valid = ImageLoader2D.load_data(*_loader_args, VAL_PATH)

# Quick sanity check on the array shapes of each split.
print("Training shapes:", x_train.shape, y_train.shape)
print("Validation shapes:", x_valid.shape, y_valid.shape)

In [None]:
# Augmentation pipeline for training samples (image and mask are transformed together).
aug_train = albu.Compose([
    # Flips use the library's default application probability.
    albu.HorizontalFlip(),
    albu.VerticalFlip(),
    # Photometric jitter, applied to every sample.
    albu.ColorJitter(
        brightness=(0.6,1.6),
        contrast=0.2,
        saturation=0.1,
        hue=0.01,
        always_apply=True,
    ),
    # Geometric jitter (scale, shift, rotation, shear), applied to every sample.
    albu.Affine(
        scale=(0.5,1.5),
        translate_percent=(-0.125,0.125),
        rotate=(-180,180),
        shear=(-22.5,22),
        always_apply=True,
    ),
])

def augment_batch(images, masks):
    """Run ``aug_train`` on every (image, mask) pair and stack the results.

    Args:
        images: Iterable of images, paired positionally with ``masks``.
        masks: Iterable of masks, one per image.

    Returns:
        A tuple ``(images_aug, masks_aug)`` of NumPy arrays holding the
        augmented images and masks in the same order as the inputs.
    """
    # Each pair goes through the pipeline exactly once, so the image and
    # its mask receive the same random transform.
    transformed = [
        aug_train(image=img, mask=msk)
        for img, msk in zip(images, masks)
    ]

    stacked_images = np.array([sample['image'] for sample in transformed])
    stacked_masks = np.array([sample['mask'] for sample in transformed])
    return stacked_images, stacked_masks

In [None]:
# Build the RAPUNet network for square IMG_SIZE x IMG_SIZE inputs with
# 3 input channels and a single output class.
# `input_chanels` is spelled as the callee's keyword expects - do not "fix" it here.
model = RAPUNet.create_model(
    starting_filters=FILTERS,
    out_classes=1,
    input_chanels=3,
    img_width=IMG_SIZE,
    img_height=IMG_SIZE,
)

# Train against the dice loss using the optimizer configured earlier.
model.compile(loss=dice_metric_loss, optimizer=optimizer)

model.summary()

In [None]:
# Training callbacks

# Overwrite best_model.h5 only when the validation loss reaches a new minimum.
checkpoint = ModelCheckpoint(
    filepath='best_model.h5',
    mode='min',
    save_best_only=True,
    monitor='val_loss',
)

# Stop after 20 epochs without validation-loss improvement and roll the
# model back to its best weights.
early_stopping = EarlyStopping(
    restore_best_weights=True,
    patience=20,
    monitor='val_loss',
)

# Record per-epoch metrics to a CSV file.
csv_logger = CSVLogger(filename='training_log.csv')

callbacks = [checkpoint, early_stopping, csv_logger]

In [None]:
# Fit the model on the in-memory training arrays, validating on the
# validation split after every epoch.
# NOTE(review): `aug_train` / `augment_batch` are defined in this file but are
# not applied here, so training runs on un-augmented data - confirm whether
# augmentation was meant to be part of this step.
history = model.fit(
    x_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=1,
    callbacks=callbacks,
    validation_data=(x_valid, y_valid),
)

In [None]:
# Load test data and generate predictions
print("Loading test data...")
# Sort the paths: glob returns files in filesystem-dependent order, and the
# output file names below are derived from each image's position in this list.
test_images = sorted(glob.glob(os.path.join(TEST_PATH, "*.jpg")))
if not test_images:
    # Fail early with a clear message instead of a shape error inside predict().
    raise FileNotFoundError(f"No .jpg test images found in {TEST_PATH!r}")

# Load and preprocess test images: resize to the model input size and
# divide by 255.
# NOTE(review): cv2.imread yields channels in BGR order. If
# ImageLoader2D.load_data (not visible here) produces RGB training data, the
# test inputs need cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - confirm and align.
x_test = []
for img_path in test_images:
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable/corrupt file by returning None.
        raise ValueError(f"Could not read test image: {img_path}")
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
    x_test.append(img / 255.0)

x_test = np.array(x_test)

# Generate predictions
predictions = model.predict(x_test, batch_size=BATCH_SIZE)

# Save predictions as binary masks (0 / 255), thresholded at 0.5.
# File index i corresponds to test_images[i] in sorted path order.
os.makedirs("predictions", exist_ok=True)
for i, pred in enumerate(predictions):
    mask = (pred > 0.5).astype(np.uint8) * 255
    out_path = f"predictions/pred_{i:04d}.png"
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(out_path, mask):
        raise IOError(f"Failed to write prediction: {out_path}")