In [2]:
import tqdm as notebook_tqdm
import gradio as gr
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import cv2
from PIL import Image
import os
import shutil

In [None]:
# --- Define Constants ---
CLASSIFICATION_MODEL_PATH = 'brain_tumor_vit_model_5extra.h5'
SEGMENTATION_MODEL_PATH = 'seg_model_v2.h5'
CLASS_NAMES = ['glioma', 'meningioma', 'notumor', 'pituitary']
IMAGE_SIZE = (224, 224)
NUM_CLASSES = len(CLASS_NAMES)

# --- Utility Functions ---
class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super(Patches, self).__init__(**kwargs)
        self.patch_size = patch_size
    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
    def get_config(self):
        config = super().get_config()
        config.update({'patch_size': self.patch_size})
        return config

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim, **kwargs):
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )
    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
    def get_config(self):
        config = super().get_config()
        config.update({'num_patches': self.num_patches, 'projection_dim': self.projection_dim})
        return config


def create_dummy_models_if_not_exist():
    """Checks for model files and creates dummy ones if they are missing."""

    # --- Classification Model (ViT) ---
    if not os.path.exists(CLASSIFICATION_MODEL_PATH):
        print(f"Placeholder classification model not found. Creating a dummy model at: {CLASSIFICATION_MODEL_PATH}")

        patch_size = 16
        num_patches = (IMAGE_SIZE[0] // patch_size) ** 2
        projection_dim = 64

        inputs = layers.Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 1))
        patches = Patches(patch_size)(inputs)
        encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
        flattened = layers.Flatten()(encoded_patches)
        outputs = layers.Dense(NUM_CLASSES, activation="softmax")(flattened)

        dummy_cls_model = keras.Model(inputs=inputs, outputs=outputs)
        dummy_cls_model.save(CLASSIFICATION_MODEL_PATH)
        print("Dummy classification model created.")

    # --- Segmentation Model ---
    if not os.path.exists(SEGMENTATION_MODEL_PATH):
        print(f"Placeholder segmentation model not found. Creating a dummy model at: {SEGMENTATION_MODEL_PATH}")

        inputs = layers.Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 1))
        conv1 = layers.Conv2D(8, (3, 3), activation='relu', padding='same')(inputs)
        outputs = layers.Conv2D(1, (1, 1), activation='sigmoid', padding='same')(conv1)

        dummy_seg_model = keras.Model(inputs=inputs, outputs=outputs)
        dummy_seg_model.save(SEGMENTATION_MODEL_PATH)
        print("Dummy segmentation model created.")

In [None]:
# --- Loading the Model ---
create_dummy_models_if_not_exist()

print("\nLoading models... This may take a moment.")
try:
    classification_model = tf.keras.models.load_model(
        CLASSIFICATION_MODEL_PATH,
        custom_objects={'Patches': Patches, 'PatchEncoder': PatchEncoder}
    )
    segmentation_model = tf.keras.models.load_model(SEGMENTATION_MODEL_PATH, compile=False)
    print("Models loaded successfully.")
except Exception as e:
    print(f"Error loading models: {e}")
    exit()


Loading models... This may take a moment.





Models loaded successfully.


In [None]:
# --- Preprocessing ---
def predict(input_image):
    """
    Takes a user-uploaded image, runs it through both models,
    and returns the classification and segmentation results.
    """
    if input_image is None:
        return None, None

    image = Image.fromarray(input_image)

    # Classification
    img_gray_cls = image.convert('L')
    img_resized_cls = img_gray_cls.resize(IMAGE_SIZE, Image.Resampling.LANCZOS)
    img_np_cls = np.array(img_resized_cls)
    img_equalized_cls = cv2.equalizeHist(img_np_cls)
    img_final_cls = img_equalized_cls.astype('float32') / 255.0
    img_final_cls = np.expand_dims(img_final_cls, axis=-1)
    img_final_cls = np.expand_dims(img_final_cls, axis=0)
    class_prediction = classification_model.predict(img_final_cls)[0]
    confidences = {CLASS_NAMES[i]: float(class_prediction[i]) for i in range(len(CLASS_NAMES))}

    # Segmentation
    img_rgb_seg = image.convert('RGB')
    img_resized_seg = img_rgb_seg.resize(IMAGE_SIZE, Image.Resampling.LANCZOS)
    img_np_seg = np.array(img_resized_seg)
    img_final_seg = img_np_seg.astype('float32') / 255.0
    img_final_seg = np.expand_dims(img_final_seg, axis=0)
    seg_mask = segmentation_model.predict(img_final_seg)[0]
    seg_mask = np.squeeze((seg_mask * 255).astype(np.uint8))

    return confidences, seg_mask

In [8]:
# --- 3. Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Brain Tumor Analysis: Classification & Segmentation
        Upload a MRI brain scan to classify the tumor type and generate a segmentation mask.
        """
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy", label="Upload MRI Image")
            submit_btn = gr.Button("Analyze Image", variant="primary")
        with gr.Column():
            output_label = gr.Label(num_top_classes=4, label="Classification Results")
            output_segmentation = gr.Image(label="Segmentation Mask")

    submit_btn.click(
        fn=predict,
        inputs=input_image,
        outputs=[output_label, output_segmentation]
    )

    gr.Examples(
        examples=[],
        inputs=input_image,
        outputs=[output_label, output_segmentation],
        fn=predict,
        cache_examples=True
    )

In [9]:
# --- 4. Launch the App ---
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch(debug=True)

Launching Gradio app...
* Running on local URL:  http://127.0.0.1:7860
Caching examples at: 'd:\Research\Barin Tumor MRI\Brain Tumor Detector\.gradio\cached_examples\19'
* To create a public link, set `share=True` in `launch()`.


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 994ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 127ms/step
Keyboard interruption in main thread... closing server.
