In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, callbacks
import numpy as np
import cv2
import os
import h5py
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

In [None]:
# Configuration (modify paths as needed)
IMAGE_SIZE = (256, 256)  # (width, height), the `dsize` order expected by cv2.resize
DATASET_DIR = "/kaggle/input/figshare-brain-tumor-dataset/dataset/data"  # Folder containing .mat files
SEED = 42  # global seed so weight init / shuffling are repeatable across Restart & Run All

# Seeds Python's `random`, NumPy and TensorFlow in one call. Note: some GPU kernels
# can still be nondeterministic, so runs may differ slightly on GPU.
tf.keras.utils.set_random_seed(SEED)

In [None]:
def _read_cjdata_slice(file_path, image_size):
    """Read one ``cjdata`` slice from ``file_path`` and return ``(image, mask)`` resized to ``image_size``.

    Raises:
        KeyError: if the file has no ``cjdata/image`` or ``cjdata/tumorMask`` dataset.
        OSError: if h5py cannot open the file (e.g. it is not an HDF5 / MATLAB v7.3 file).
    """
    with h5py.File(file_path, 'r') as data:
        # h5py exposes MATLAB v7.3 (column-major) arrays with their axes reversed;
        # transpose to recover MATLAB's row/column orientation. Image and mask are
        # transposed identically, so they stay pixel-aligned.
        img = np.ascontiguousarray(np.array(data['cjdata']['image'], dtype=np.float32).T)
        mask = np.ascontiguousarray(np.array(data['cjdata']['tumorMask'], dtype=np.uint8).T)

    img = cv2.resize(img, image_size)
    # NOTE(review): intensities are read as float (not uint8) and may exceed 255, in
    # which case a fixed division by 255 does not bound them to [0, 1]. Per-slice
    # min-max scaling always does; a constant slice maps to all zeros.
    lo, hi = float(img.min()), float(img.max())
    img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)

    # Nearest-neighbour interpolation keeps the label map binary instead of blending 0/1.
    mask = cv2.resize(mask, image_size, interpolation=cv2.INTER_NEAREST)
    mask = (mask > 0).astype(np.uint8)  # Ensure binary mask
    return img, mask


def load_mat_files(dataset_dir, image_size=None):
    """Load MRI slices and tumour masks from the ``.mat`` files in ``dataset_dir``.

    Each file is expected to hold a ``cjdata`` group with ``image`` and ``tumorMask``
    datasets. Files missing either key, or that h5py cannot open, are skipped with a
    printed message. Files are processed in sorted filename order (non-recursive).

    Args:
        dataset_dir: directory scanned for ``*.mat`` files.
        image_size: ``(width, height)`` passed to ``cv2.resize``; defaults to the
            module-level ``IMAGE_SIZE``.

    Returns:
        ``(images, masks)``. ``images`` is float32 of shape ``(N, height, width, 1)``,
        min-max scaled per slice to ``[0, 1]``. ``masks`` is uint8 of the same shape
        with values in ``{0, 1}``. When no file loads, both arrays have ``N == 0``.
    """
    if image_size is None:
        image_size = IMAGE_SIZE

    images, masks = [], []
    for file in sorted(os.listdir(dataset_dir)):
        if not file.endswith(".mat"):
            continue
        file_path = os.path.join(dataset_dir, file)
        try:
            img, mask = _read_cjdata_slice(file_path, image_size)
        except KeyError as e:
            print(f"Skipping {file}: KeyError - {e}")
            continue
        except OSError as e:
            # Unreadable / non-HDF5 file: skip it like a malformed one rather than
            # aborting the whole load.
            print(f"Skipping {file}: OSError - {e}")
            continue
        images.append(img[..., np.newaxis])  # Add channel dimension
        masks.append(mask[..., np.newaxis])

    if not images:
        # Keep the (N, H, W, 1) rank even when nothing was loaded.
        width, height = image_size
        empty_shape = (0, height, width, 1)
        return np.empty(empty_shape, dtype=np.float32), np.empty(empty_shape, dtype=np.uint8)
    return np.array(images), np.array(masks)

In [None]:
# Load images and masks
images, masks = load_mat_files(DATASET_DIR)
print(f"Dataset loaded: {len(images)} images, {len(masks)} masks")
# Fail fast: an empty or misaligned dataset would otherwise only surface later
# (e.g. inside train_test_split or model.fit) with a less obvious message.
assert len(images) > 0, f"No usable .mat files found in {DATASET_DIR}"
assert images.shape[:3] == masks.shape[:3], (
    f"Image/mask shape mismatch: {images.shape} vs {masks.shape}"
)

In [None]:
# Split into training, validation, and test sets
# 70% train; the held-out 30% is then split 2:1 into validation (20% overall)
# and test (10% overall). Both splits use the same fixed random_state.
# NOTE(review): this split is per slice. If several slices come from the same
# patient, they can land in different splits and inflate scores — consider a
# grouped split by patient ID (confirm whether the .mat files expose one).
X_train, X_temp, y_train, y_temp = train_test_split(
    images,
    masks,
    test_size=0.3,
    random_state=42,
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp,
    y_temp,
    test_size=1 / 3,
    random_state=42,
)
print(f"Train: {len(X_train)}, Validation: {len(X_val)}, Test: {len(X_test)}")

In [None]:
# U-Net Model
def unet_model(input_shape=(256, 256, 1)):
    """Build a 4-level U-Net for single-channel binary segmentation.

    The encoder has stages with 64/128/256/512 filters. Each stage is two 3x3 ReLU
    convolutions followed by 2x2 max-pooling. A 1024-filter bottleneck follows, then
    a mirrored decoder. Each decoder stage upsamples with a 2x2 stride-2 transposed
    convolution and concatenates the matching encoder features (skip connection).

    Args:
        input_shape: ``(height, width, channels)`` of the network input. Height and
            width must be divisible by 16 (four poolings) for the skips to line up.

    Returns:
        An uncompiled Keras ``Model`` producing a 1-channel sigmoid map with the
        same spatial size as the input.
    """
    def double_conv(tensor, n_filters):
        # Two same-padded 3x3 ReLU convolutions; spatial size is preserved.
        for _ in range(2):
            tensor = layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    encoder_filters = (64, 128, 256, 512)

    inputs = layers.Input(input_shape)

    # Encoder: remember each stage's output (pre-pooling) for the skip connections.
    x = inputs
    skips = []
    for n_filters in encoder_filters:
        x = double_conv(x, n_filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = double_conv(x, 1024)

    # Decoder: walk the encoder stages deepest-first.
    for n_filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = layers.Conv2DTranspose(n_filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = double_conv(x, n_filters)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)
    return Model(inputs, outputs)

In [None]:
# Compile Model
# Derive the input shape from IMAGE_SIZE so the network always matches the
# preprocessing: cv2.resize takes dsize=(width, height) and returns arrays of
# shape (height, width).
model = unet_model(input_shape=(IMAGE_SIZE[1], IMAGE_SIZE[0], 1))
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[
        'accuracy',
        # Pixel accuracy can look high when tumour pixels are a small fraction of each
        # mask (typical for this task — confirm on the data), so also track the IoU
        # of the foreground class at the same 0.5 threshold used for evaluation.
        tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5, name='tumor_iou'),
    ],
)

In [None]:
# Early Stopping & Learning Rate Scheduler
# All three callbacks watch validation loss (their Keras default, spelled out here).
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,                 # stop after 5 epochs without improvement
    restore_best_weights=True,
)
lr_scheduler = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,                 # halve the learning rate ...
    patience=3,                 # ... after 3 epochs without improvement
)
checkpoint = callbacks.ModelCheckpoint(
    "best_model.keras",
    monitor='val_loss',
    save_best_only=True,        # overwrite only when val_loss improves
)

In [None]:
# Train Model
# Up to 50 epochs in mini-batches of 8; the callbacks may stop training earlier,
# adjust the learning rate, and checkpoint the best epoch.
history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_val, y_val),
    epochs=50,
    batch_size=8,
    callbacks=[early_stopping, lr_scheduler, checkpoint],
)


In [None]:
# Save Model
# Persists the weights the model holds after fit(). EarlyStopping was created with
# restore_best_weights=True, so these may be the best epoch's weights. The
# lowest-val_loss checkpoint is also kept separately in "best_model.keras".
model.save(filepath="brain_tumor_unet.keras")

In [None]:
# Training Loss Plot
# Explicit Figure/Axes interface, with a title and axis labels so the figure stands
# on its own when the notebook is skimmed.
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(history.history['loss'], label='Training Loss')
ax.plot(history.history['val_loss'], label='Validation Loss')
ax.set(
    title='U-Net training curve',
    xlabel='Epoch',
    ylabel='Loss (binary cross-entropy)',
)
ax.legend()
plt.show()

In [None]:
# Model Evaluation
# Threshold the sigmoid probabilities at 0.5 to get hard {0, 1} masks. The ground
# truth goes through the same threshold, so both arrays are uint8 with matching values.
y_pred = model.predict(X_test)
y_pred_bin = np.greater(y_pred, 0.5).astype(np.uint8)
y_test_bin = np.greater(y_test, 0.5).astype(np.uint8)


In [None]:
# Confusion Matrix & Dice Score
from sklearn.metrics import confusion_matrix

def dice_coefficient(y_true, y_pred):
    """Dice similarity coefficient ``2*|A∩B| / (|A| + |B|)`` pooled over all elements.

    The score is computed globally over the whole arrays (e.g. every test image at
    once), not averaged per image. Inputs are cast to float64 so integer inputs cannot
    overflow in the element-wise product.

    Args:
        y_true: ground-truth mask array (binary, or soft values for a soft Dice).
        y_pred: predicted mask array with the same shape as ``y_true``.

    Returns:
        float in ``[0, 1]`` for binary inputs. Two empty masks score 1.0 (perfect
        agreement) instead of the 0 an epsilon-only denominator would give.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    denominator = y_true.sum() + y_pred.sum()
    if denominator == 0:
        return 1.0
    return float(2.0 * np.sum(y_true * y_pred) / denominator)

# Global (all test pixels pooled) overlap metrics; the intersection is computed once
# and reused for the IoU.
intersection = np.sum(y_test_bin * y_pred_bin)
union = np.sum(y_test_bin) + np.sum(y_pred_bin) - intersection
dice_score = dice_coefficient(y_test_bin, y_pred_bin)
iou_score = intersection / (union + 1e-7)

print("Dice Score:", dice_score)
print("IoU Score:", iou_score)

# labels=[0, 1] pins the matrix to 2x2 — rows are true labels, columns are predicted
# labels, i.e. [[TN, FP], [FN, TP]] — even if one class never occurs in either array.
cm = confusion_matrix(y_test_bin.ravel(), y_pred_bin.ravel(), labels=[0, 1])
print("Confusion Matrix:")
print(cm)


In [None]:
# Evaluate Model
# Reuses y_test_bin / y_pred_bin and dice_coefficient from the evaluation cells
# above. Re-running model.predict over the whole test set repeated the same
# inference, and re-defining dice_coefficient here silently shadowed the earlier
# definition.
dice_score = dice_coefficient(y_test_bin, y_pred_bin)
print("Dice Score:", dice_score)