In [2]:
%%writefile app.py
import streamlit as st
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import GaussianNoise

# Register custom layer so load_model can resolve it by name.
custom_objects = {'GaussianNoise': GaussianNoise}


@st.cache_resource
def _load_model(path='EN_B0v9999.keras'):
    """Load the saved Keras classifier once per Streamlit server process.

    Streamlit re-executes this script from top to bottom on every widget
    interaction; caching the loaded model avoids re-reading and rebuilding
    it from disk on each rerun.

    Args:
        path: location of the saved ``.keras`` model file.

    Returns:
        The deserialized Keras model.
    """
    return tf.keras.models.load_model(path, custom_objects=custom_objects)


# Load model (cached across reruns)
model = _load_model()
# Labels indexed by output position: the script maps argmax(prediction) -> class_names[i].
class_names = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']

def format_text_list(items):
    """Normalise free-text list entries for display in the UI and PDF report.

    Each string is stripped and its first character upper-cased. The rest of
    the text is kept as typed, so acronyms such as "MRI" survive
    (``str.capitalize`` would lower-case them to "Mri"). Non-string entries
    and entries that are empty after stripping are dropped.

    Args:
        items: iterable of values, e.g. multiselect choices plus any custom
            text the user entered.

    Returns:
        list[str]: the cleaned entries, in their original order.
    """
    cleaned = []
    for item in items:
        if not isinstance(item, str):
            continue
        text = item.strip()
        if text:  # skip blank / whitespace-only entries
            cleaned.append(text[0].upper() + text[1:])
    return cleaned


# Title
st.title("🧠 Dementia MRI Classifier")
st.subheader("Welcome to BrainCheckBot! Analyze your MRI scan.")

# Inputs
name = st.text_input("👤 Patient’s Name")
age = st.text_input("📆 Age (Optional)")
gender = st.radio("⚧️ Gender", ["Male", "Female", "Prefer not to say"], horizontal=True)

symptom_options = [
    "Memory loss", "Confusion", "Headaches", "Dizziness",
    "Mood changes", "Difficulty concentrating", "Vision issues", "Personality changes", "Other",
]
symptoms = st.multiselect("🧠 Symptoms", options=symptom_options)

if "Other" in symptoms:
    custom_symptom = st.text_input("✍️ Please specify other symptom", key="custom_symptom")
    if custom_symptom.strip():
        # Replace the generic "Other" placeholder with the user's text so the
        # report does not list both.
        symptoms = [s for s in symptoms if s != "Other"] + [custom_symptom]

reason_options = [
    "Routine check", "Family history", "Head trauma",
    "Physician referral", "Follow-up scan", "Clinical trial screening", "Other",
]
reason = st.multiselect("📄 Reason for Scan", options=reason_options)

if "Other" in reason:
    # Own label and key: a text_input with the same parameters as the symptom
    # field would collide (Streamlit duplicate widget ID) when both "Other"
    # options are selected.
    custom_reason = st.text_input("✍️ Please specify other reason", key="custom_reason")
    if custom_reason.strip():
        reason = [r for r in reason if r != "Other"] + [custom_reason]

uploaded_file = st.file_uploader("📸 Upload an MRI scan...", type=["png", "jpg", "jpeg"])

if uploaded_file and name:
    # Clean and format inputs
    name = name.title().strip()
    gender = gender.capitalize().strip()
    symptoms = format_text_list(symptoms)
    reason = format_text_list(reason)

    # Show uploaded image. Scans are often single-channel (mode "L") or carry
    # an alpha channel, while the EfficientNet classifier takes 3-channel RGB
    # input, so normalise once here (generate_pdf also reuses `image`).
    image = Image.open(uploaded_file).convert("RGB")
    st.image(image, caption='🧠 MRI Image Preview', use_column_width=True)

    # Preprocess: 224x224 RGB -> float32 -> EfficientNet preprocessing -> batch of 1
    img = image.resize((224, 224))
    img = np.array(img).astype('float32')
    img = tf.keras.applications.efficientnet.preprocess_input(img)
    img = np.expand_dims(img, axis=0)

    # Predict: index of the highest score selects the label
    prediction = model.predict(img)
    predicted_class = int(np.argmax(prediction, axis=1)[0])
    predicted_label = class_names[predicted_class]

    # Display result
    st.markdown(f"### ✅ Predicted Class: **{predicted_label}**")
    st.write("🔍 Raw Prediction Vector:", prediction[0])




# Subtle info box: a collapsed legend explaining each label the classifier
# can output.
_CLASS_LEGEND_MD = """
    - **NonDemented**: No signs of dementia detected.  
    - **VeryMildDemented**: Very early signs of cognitive decline.  
    - **MildDemented**: Noticeable cognitive decline, may need medical follow-up.  
    - **ModerateDemented**: More severe symptoms that require urgent clinical evaluation.
    """
_legend_box = st.expander("ℹ️ What do the classes mean?")
_legend_box.markdown(_CLASS_LEGEND_MD)



from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.utils import ImageReader
import io
from datetime import datetime

# Recommendation sentence printed in the PDF report, keyed by predicted class label.
doctor_notes = dict(
    NonDemented="No signs of dementia detected. Maintain regular checkups and a healthy lifestyle.",
    VeryMildDemented="Very mild cognitive symptoms observed. Recommend regular monitoring and lifestyle changes.",
    MildDemented="Mild dementia detected. Clinical evaluation and cognitive therapies advised.",
    ModerateDemented="Moderate dementia identified. Consult a neurologist for personalized treatment planning.",
)

def generate_pdf():
    """Render the current analysis as a one-page PDF report.

    Reads module-level state set by the Streamlit script: ``name``, ``age``,
    ``gender``, ``symptoms``, ``reason``, ``predicted_label`` and ``image``.
    It also uses the ``doctor_notes`` table, so it may only be called after a
    prediction has been made.

    Layout:
        - Demographics, findings and diagnosis in a left column.
        - The MRI preview in a 200x200 box at x=370, beside the diagnosis.
        - The recommendation below the image, then a disclaimer footer.

    Body text is word-wrapped, so long free-text entries cannot run under the
    image or off the page.

    Returns:
        io.BytesIO: buffer rewound to position 0 holding the PDF bytes.
    """
    from reportlab.lib.utils import simpleSplit

    page_width, _ = letter
    line_height = 18
    body_font, body_size = "Helvetica", 11
    image_x, image_size = 370, 200
    column_right = image_x - 10  # left-column text stops short of the image box
    page_right = page_width - 50

    buffer = io.BytesIO()
    c = canvas.Canvas(buffer, pagesize=letter)

    def draw_wrapped(text, x, y, right=column_right):
        """Draw body text wrapped between x and `right`; return the y of the next line."""
        for line in simpleSplit(text, body_font, body_size, right - x) or [""]:
            c.drawString(x, y, line)
            y -= line_height
        return y

    def draw_bullets(items, placeholder, y):
        """Draw one bullet per item, or a single placeholder bullet; return the next y."""
        for item in items or [placeholder]:
            y = draw_wrapped(f"• {item}", 90, y)
        return y

    # Header (centred on the actual page width)
    c.setFont("Helvetica-Bold", 18)
    c.drawCentredString(page_width / 2, 750, "MRI Report")
    c.setFont("Helvetica", 10)
    c.drawCentredString(
        page_width / 2, 735,
        f"Generated on {datetime.now().strftime('%B %d, %Y at %I:%M %p')}",
    )

    y = 700

    # Section 1: Patient Demographics
    c.setFont("Helvetica-Bold", 12)
    c.drawString(50, y, "Patient Demographics")
    y -= line_height
    c.setFont(body_font, body_size)
    y = draw_wrapped(f"Name: {name}", 70, y)
    if age:
        y = draw_wrapped(f"Age: {age}", 70, y)
    y = draw_wrapped(f"Gender: {gender}", 70, y)
    y -= 10  # Extra space

    # Section 2: Medical Examination Findings
    c.setFont("Helvetica-Bold", 12)
    c.drawString(50, y, "Medical Examination Findings")
    y -= line_height
    c.setFont(body_font, body_size)
    c.drawString(70, y, "Symptoms:")
    y -= line_height
    y = draw_bullets(symptoms, "None reported", y)

    y -= 5
    c.drawString(70, y, "Reason for Scan:")
    y -= line_height
    y = draw_bullets(reason, "Not specified", y)

    # Gap before diagnosis and image
    y -= 20

    # Section 3: Diagnostic Results
    c.setFont("Helvetica-Bold", 12)
    c.drawString(50, y, "Diagnostic Results")
    y -= line_height
    c.setFont(body_font, body_size)
    c.drawString(70, y, f"Predicted Diagnosis: {predicted_label}")

    # MRI image in a fixed box on the right whose bottom edge is 20pt below
    # the diagnosis line; aspect ratio is preserved so the scan isn't distorted.
    try:
        image_rgb = image.convert("RGB")
        img_io = io.BytesIO()
        image_rgb.save(img_io, format='PNG')
        img_io.seek(0)
        c.drawImage(
            ImageReader(img_io), image_x, y - 20,
            width=image_size, height=image_size, preserveAspectRatio=True,
        )
    except Exception:
        # Best effort: a preview that cannot be embedded must not prevent the
        # rest of the report from being produced.
        c.drawString(image_x, y - 20, "[Image error]")

    # Move below the image's bottom edge plus a gap before the next section.
    y -= 60

    # Section 4: Doctor’s Recommendation (below the image, so full width)
    c.setFont("Helvetica-Bold", 12)
    c.drawString(50, y, "Doctor’s Recommendation")
    y -= line_height
    c.setFont(body_font, body_size)
    draw_wrapped(doctor_notes[predicted_label], 70, y, right=page_right)

    # Footer
    c.setFont("Helvetica-Oblique", 9)
    c.drawString(50, 40, "Note: This report does not replace professional medical advice.")

    c.showPage()
    c.save()
    buffer.seek(0)
    return buffer

# The report needs the prediction computed above, which only exists once an
# MRI is uploaded and a name entered; without this guard, clicking the button
# would make generate_pdf() fail with NameError on `predicted_label`/`image`.
if uploaded_file and name:
    if st.button("📄 Generate PDF Report"):
        pdf = generate_pdf()
        # The name is free user input: keep only characters that are safe in a
        # download file name (no slashes, spaces, etc.).
        safe_name = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in name) or "patient"
        st.download_button(
            label="⬇️ Download Report",
            data=pdf,
            file_name=f"{safe_name}_BrainCheckBot_Report.pdf",
            mime="application/pdf",
        )

# Footer with model info
st.markdown("""
<hr style="margin-top: 2em;">
<div style='text-align: center; font-size: 0.85em; color: gray;'>
Model: EfficientNetB0 | Project II 2025
</div>
""", unsafe_allow_html=True)



Overwriting app.py
