# Brain Tumor Segmentation with U-Net

This notebook implements a U-Net model for brain tumor segmentation using MRI images. It follows a structured pipeline: dataset setup, exploratory data analysis (EDA), plotting, data augmentation, preprocessing, image loading, data splitting, model building, training, and evaluation.

**Dependencies**: TensorFlow, NumPy, Matplotlib, OpenCV, Scikit-learn

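**Dataset**: Assumes MRI images and binary masks stored in two sibling directories under a common root (e.g., `path/to/dataset/images` and `path/to/dataset/masks`), with exactly one mask per image. Update `data_dir` in Section 1 to match your dataset.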

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import os
import matplotlib.pyplot as plt
import cv2
from sklearn.model_selection import train_test_split
%matplotlib inline

## 1. Dataset

Define paths to the dataset containing MRI images and corresponding segmentation masks. The dataset should have subdirectories for images and masks.

In [None]:
# Root folder of the dataset (adjust based on your dataset structure).
# The image and mask folders are expected directly beneath it.
data_dir = 'path/to/dataset'
image_dir, mask_dir = (os.path.join(data_dir, sub) for sub in ('images', 'masks'))

# Spatial size used for preprocessing and the model input, and the
# mini-batch size used for training.
IMG_HEIGHT, IMG_WIDTH = 128, 128
BATCH_SIZE = 16

## 2. Exploratory Data Analysis (EDA)

Analyze the dataset to understand its structure, including the number of images and masks, and check for consistency.

In [None]:
# Get list of image and mask files.
# os.listdir() returns entries in arbitrary order, so both lists are sorted:
# images and masks are paired by index in the cells below, which only works
# when the two lists line up. This assumes the two directories use parallel
# file naming, so that sorted order matches image i with mask i.
image_files = sorted(
    os.path.join(image_dir, f)
    for f in os.listdir(image_dir)
    if f.lower().endswith(('.jpg', '.png'))  # also accept .JPG / .PNG
)
mask_files = sorted(
    os.path.join(mask_dir, f)
    for f in os.listdir(mask_dir)
    if f.lower().endswith(('.jpg', '.png'))
)

# Check dataset size and consistency
print(f'Number of images: {len(image_files)}')
print(f'Number of masks: {len(mask_files)}')
# Explicit raises instead of assert: assertions are stripped under `python -O`.
if len(image_files) != len(mask_files):
    raise ValueError('Number of images and masks must match')
if not image_files:
    raise ValueError(f'No .jpg/.png files found in {image_dir}')

# Check sample image and mask dimensions
sample_img = cv2.imread(image_files[0])
sample_mask = cv2.imread(mask_files[0], cv2.IMREAD_GRAYSCALE)
# cv2.imread reports an unreadable file by returning None rather than raising.
if sample_img is None or sample_mask is None:
    raise ValueError(f'Could not read {image_files[0]} or {mask_files[0]}')
print(f'Sample image shape: {sample_img.shape}')
print(f'Sample mask shape: {sample_mask.shape}')

## 3. Plot

Visualize sample images and their corresponding masks to understand the data distribution and verify alignment.

In [None]:
# Plot sample images (top row) with their masks (bottom row).
# Show up to 3 pairs, but never more than the dataset contains, so the
# indexing below cannot run past the end of the file lists.
n_samples = min(3, len(image_files), len(mask_files))
plt.figure(figsize=(10, 5))
for i in range(n_samples):
    img = cv2.imread(image_files[i])
    # OpenCV loads colour images as BGR; Matplotlib expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    mask = cv2.imread(mask_files[i], cv2.IMREAD_GRAYSCALE)

    plt.subplot(2, n_samples, i + 1)
    plt.imshow(img)
    plt.title('MRI Image')
    plt.axis('off')

    # Same column, second row.
    plt.subplot(2, n_samples, i + 1 + n_samples)
    plt.imshow(mask, cmap='gray')
    plt.title('Mask')
    plt.axis('off')
plt.tight_layout()
plt.show()

## 4. Data Augmentation

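Define a generator that applies the same random augmentations to images and their masks, to increase dataset diversity and reduce overfitting. Note that `create_data_generator` is only defined in this section: the training cell in Section 8 fits on the in-memory arrays loaded in Section 6 and does not use it.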

In [None]:
def create_data_generator(img_dir, mask_dir, target_size, batch_size, augment=False):
    """Return an iterator of ``(image_batch, mask_batch)`` pairs.

    Images are read as RGB and masks as grayscale; both are rescaled by
    1/255. Mask batches are re-binarized at 0.5 before being yielded.

    Args:
        img_dir: Directory handed to ``flow_from_directory`` for the images.
        mask_dir: Directory handed to ``flow_from_directory`` for the masks.
        target_size: Size forwarded to ``flow_from_directory`` for both flows.
        batch_size: Number of samples per yielded batch.
        augment: If True, apply random rotation, shift, shear, zoom and
            horizontal flip to images and masks.

    Returns:
        A generator yielding ``(image_batch, mask_batch)`` tuples.

    NOTE(review): per the Keras docs, ``flow_from_directory`` looks for files
    inside subdirectories of the given directory even with
    ``class_mode=None`` -- confirm the dataset layout before using this.
    """
    # A single shared configuration, so the image and mask pipelines cannot
    # drift apart when someone edits one of the augmentation ranges.
    datagen_kwargs = {'rescale': 1. / 255}
    if augment:
        datagen_kwargs.update(
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest',
        )
    img_datagen = ImageDataGenerator(**datagen_kwargs)
    mask_datagen = ImageDataGenerator(**datagen_kwargs)

    # Both flows get the same seed and shuffle setting, following the
    # image/mask example in the Keras documentation, so that batches (and
    # their random transforms) stay paired.
    flow_kwargs = {
        'target_size': target_size,
        'batch_size': batch_size,
        'class_mode': None,
        'shuffle': True,
        'seed': 42,
    }
    img_generator = img_datagen.flow_from_directory(
        img_dir, color_mode='rgb', **flow_kwargs
    )
    mask_generator = mask_datagen.flow_from_directory(
        mask_dir, color_mode='grayscale', **flow_kwargs
    )

    def paired_batches():
        for img_batch, mask_batch in zip(img_generator, mask_generator):
            # Resizing and geometric augmentation interpolate pixel values,
            # so masks come back with fractional entries; threshold them so
            # they are 0/1 like the output of preprocess_mask.
            yield img_batch, (mask_batch > 0.5).astype(mask_batch.dtype)

    # A real generator is returned rather than the bare zip object, since
    # tf.keras Model.fit accepts Python generators as input but not `zip`.
    return paired_batches()

## 5. Data Preprocessing

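Preprocess images and masks with OpenCV. Images are converted from BGR to grayscale, smoothed with a Gaussian blur, thresholded, and cleaned up with an erosion followed by a dilation; the result is resized, replicated to 3 channels, and scaled to [0, 1]. Masks are thresholded to binary, resized, and scaled to [0, 1]. Contours of the cleaned-up image can additionally be extracted with `cv2.findContours` for visual inspection; they do not affect the arrays passed to the model.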

In [None]:
def preprocess_image(img_path, target_size=(IMG_HEIGHT, IMG_WIDTH)):
    """Load an image file and convert it into a 3-channel model input.

    Steps: BGR -> grayscale, 5x5 Gaussian blur, fixed threshold at 127,
    3x3 erosion followed by dilation (removes small bright specks), resize,
    replicate the single channel to 3 channels, divide by 255.

    Args:
        img_path: Path of the image file to read.
        target_size: Output size as ``(height, width)``.

    Returns:
        Float array of shape ``(height, width, 3)`` with values in [0, 1].

    Raises:
        ValueError: If the file cannot be read as an image.
    """
    img = cv2.imread(img_path)
    # cv2.imread returns None for a missing or undecodable file; fail here
    # with the offending path instead of with an opaque error in cvtColor.
    if img is None:
        raise ValueError(f'Could not read image: {img_path}')

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    _, thresh = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY)

    kernel = np.ones((3, 3), np.uint8)
    eroded = cv2.erode(thresh, kernel, iterations=1)
    dilated = cv2.dilate(eroded, kernel, iterations=1)
    # The earlier contour extraction/drawing step was removed: its result was
    # never used in the returned array.

    # cv2.resize takes dsize as (width, height), the reverse of target_size.
    height, width = target_size
    resized = cv2.resize(dilated, (width, height))

    # The model input has 3 channels, so replicate the single channel.
    processed_img = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    return processed_img / 255.0

def preprocess_mask(mask_path, target_size=(IMG_HEIGHT, IMG_WIDTH)):
    """Load a mask file and convert it into a binary single-channel target.

    Args:
        mask_path: Path of the mask file to read (loaded as grayscale).
        target_size: Output size as ``(height, width)``.

    Returns:
        Float array of shape ``(height, width, 1)`` containing only 0.0 and 1.0.

    Raises:
        ValueError: If the file cannot be read as an image.
    """
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None for a missing or undecodable file.
    if mask is None:
        raise ValueError(f'Could not read mask: {mask_path}')

    # Apply thresholding to ensure binary mask
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    # cv2.resize takes dsize as (width, height), the reverse of target_size.
    # Nearest-neighbour interpolation keeps the mask strictly binary; the
    # default (bilinear) would blend 0 and 255 into gray values at the edges.
    height, width = target_size
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

    mask = mask / 255.0
    # Add a trailing channel axis.
    return mask[..., np.newaxis]

## 6. Image Loading

Load and preprocess images and masks using the defined preprocessing functions.

In [None]:
# Load images and masks
images = np.array([preprocess_image(f) for f in image_files])
masks = np.array([preprocess_mask(f) for f in mask_files])

print(f'Loaded images shape: {images.shape}')
print(f'Loaded masks shape: {masks.shape}')

## 7. Data Splitting (Train, Test, Validation)

Split the dataset into training, validation, and test sets (e.g., 70% train, 15% validation, 15% test).

In [None]:
# Split data
X_train, X_temp, y_train, y_temp = train_test_split(images, masks, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

print(f'Training set: {X_train.shape[0]} images')
print(f'Validation set: {X_val.shape[0]} images')
print(f'Test set: {X_test.shape[0]} images')

## 8. U-Net Model Building and Training

Define and train the U-Net model for segmentation.

In [None]:
def unet_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)):
    """Build the U-Net used for segmentation.

    The encoder has four levels (64, 128, 256, 512 filters), each a pair of
    3x3 convolutions followed by 2x2 max pooling. A 1024-filter bottleneck
    follows, then four decoder levels that upsample by 2, concatenate the
    matching encoder output, and apply another pair of 3x3 convolutions.
    A final 1x1 convolution with sigmoid activation produces one output channel.

    Args:
        input_shape: Shape of a single input sample passed to ``Input``.

    Returns:
        An uncompiled Keras ``Model`` mapping the input to the sigmoid output.
    """

    def double_conv(tensor, filters):
        # Two consecutive 3x3 same-padded ReLU convolutions.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_shape)

    # Encoder: keep each level's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = double_conv(x, 1024)

    # Decoder: walk back up, pairing each level with the encoder output of
    # the same resolution (deepest first).
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = UpSampling2D((2, 2))(x)
        x = Concatenate()([x, skip])
        x = double_conv(x, filters)

    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    return Model(inputs, outputs)

# Number of training epochs.
EPOCHS = 20

# Build the network and configure it with the Adam optimizer, binary
# cross-entropy loss, and accuracy as the reported metric.
model = unet_model()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Print the layer-by-layer summary.
model.summary()

# Fit on the training split; the validation split is passed as validation data.
history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
)

## 9. Evaluation and Visualization

Evaluate the model on the test set and visualize training/validation metrics and sample predictions.

In [None]:
# Score the trained model on the test split.
test_loss, test_accuracy = model.evaluate(x=X_test, y=y_test)
print(f'Test Accuracy: {test_accuracy * 100:.2f}%')

# Write the trained model to disk.
model.save(filepath='brain_tumor_unet_model.h5')

# Plot training vs. validation curves: accuracy in the left panel, loss in
# the right one.
plt.figure(figsize=(12, 4))
for panel, (metric, label) in enumerate((('accuracy', 'Accuracy'), ('loss', 'Loss')), start=1):
    plt.subplot(1, 2, panel)
    plt.plot(history.history[metric], label=f'Training {label}')
    plt.plot(history.history[f'val_{metric}'], label=f'Validation {label}')
    plt.title(f'Model {label}')
    plt.xlabel('Epoch')
    plt.ylabel(label)
    plt.legend()
plt.show()

# Visualize sample predictions: one column per test sample; the rows are
# input image, ground-truth mask, and predicted mask.
# Show up to 3 samples, but never more than the test set contains, so the
# indexing below cannot raise IndexError on a small dataset.
n_show = min(3, len(X_test))
if n_show:
    predictions = model.predict(X_test[:n_show])
    plt.figure(figsize=(10, 5))
    for i in range(n_show):
        plt.subplot(3, n_show, i + 1)
        plt.imshow(X_test[i])
        plt.title('Input Image')
        plt.axis('off')

        # Same column, second row.
        plt.subplot(3, n_show, i + 1 + n_show)
        plt.imshow(y_test[i].squeeze(), cmap='gray')
        plt.title('Ground Truth')
        plt.axis('off')

        # Same column, third row.
        plt.subplot(3, n_show, i + 1 + 2 * n_show)
        plt.imshow(predictions[i].squeeze(), cmap='gray')
        plt.title('Prediction')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
else:
    print('No test samples available to visualize.')