# 🏥 **Osteoarthritis Classification Clinical Deployment System**
*Complete deployment solution with UI and LLM-powered treatment recommendations*

---

# ⚠️ **HOW TO RUN THIS APPLICATION**

**EASIEST METHOD: Double-click `start_clinical_app.bat` in the main folder**

**Alternative methods:**
- Command line: `streamlit run clinical_app_standalone.py`
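- Notebook: `streamlit run` only accepts Python scripts, so export the notebook first (`jupyter nbconvert --to script notebooks/05_Deployment_and_Clinical_UI.ipynb`) and run the resulting `.py` file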

**⚠️ Do NOT execute notebook cells directly** - this notebook is for code viewing and documentation.
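
Because `streamlit run` expects a `.py` file, the notebook's code cells have to be exported before the notebook itself can be launched. `jupyter nbconvert --to script` does this; the sketch below is a dependency-free alternative using only the standard library. The helper name `notebook_to_script` is hypothetical, and it assumes nbformat v4 JSON, where each cell's `source` is a list of lines or a single string.

```python
import json
from pathlib import Path

def notebook_to_script(nb_path: str, out_path: str) -> int:
    """Write the code cells of a .ipynb file to a plain .py script (hypothetical helper).

    Markdown cells are skipped. Returns the number of code cells written.
    """
    nb = json.loads(Path(nb_path).read_text(encoding="utf-8"))
    chunks = []
    for cell in nb.get("cells", []):
        if cell.get("cell_type") != "code":
            continue
        src = cell.get("source", "")
        # nbformat v4 stores source either as a list of lines or as one string
        chunks.append("".join(src) if isinstance(src, list) else src)
    Path(out_path).write_text("\n\n".join(chunks) + "\n", encoding="utf-8")
    return len(chunks)
```

Cells containing IPython magics (`%pip`, `!ls`) are copied verbatim and would have to be removed by hand before running the script with `streamlit run`.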

---

# **Purpose**

This notebook defines a clinical deployment system for knee osteoarthritis severity classification. It includes:

1. **Image Classification UI** - Upload knee X-rays and get severity predictions
2. **LLM Treatment Planner** - Treatment recommendations (GPT-4 via the OpenAI API when a key is configured, otherwise a rule-based fallback) based on:
   - X-ray classification results
   - Patient demographics (age, gender)
   - Symptoms and expectations
   - Evidence-based medical guidelines

3. **Clinical Dashboard** - Track predictions, generate reports, and manage patient data

---

# **Key Features**

**AI-Powered Diagnosis:**
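- Trained PyTorch models for 5-class osteoarthritis severity (Normal, Doubtful, Mild, Moderate, Severe)
- Temperature-calibrated confidence scores for every class
- Weighted-logit ensemble of CNN backbones, or the single best model when the checkpoint selects one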
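
The ensemble and calibration steps reduce to a few lines of arithmetic. The sketch below is a minimal NumPy version (no PyTorch needed) that mirrors `EnsembleModel` and `TemperatureScaling` in the model cell: a weighted sum of per-model logits, division by a temperature `T`, then a softmax. The logits, weights, and temperature are made-up illustration values, not taken from a trained checkpoint.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def calibrated_ensemble_probs(logits_per_model: np.ndarray,
                              weights: np.ndarray,
                              temperature: float) -> np.ndarray:
    """Combine per-model logits [M, B, C] using weights [M], then temperature-scale."""
    ens_logits = (logits_per_model * weights[:, None, None]).sum(axis=0)  # [B, C]
    return softmax(ens_logits / temperature)

# Two illustrative models, one image, five severity classes
logits = np.array([
    [[2.0, 1.0, 0.5, 0.0, -1.0]],
    [[1.0, 2.0, 0.5, 0.0, -1.0]],
])
weights = np.array([0.7, 0.3])
probs = calibrated_ensemble_probs(logits, weights, temperature=1.5)
print(probs.round(3), probs.argmax(axis=1))
```

Because `T > 0` divides every logit equally, temperature scaling never changes the predicted class; it only softens (`T > 1`) or sharpens (`T < 1`) the reported confidence.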

**Holistic Treatment Planning:**
- Evidence-based recommendations from peer-reviewed literature
- Personalized treatment plans considering patient factors
- Non-pharmacological and pharmacological options
- Lifestyle and exercise recommendations

**Modern Web Interface:**
- Streamlit-based responsive UI
- Real-time predictions
- Secure patient data handling
- Professional medical interface design

**Clinical Integration:**
- Export capabilities for medical records
- Batch processing for multiple patients
- Audit trail and logging
- HIPAA-compliant design principles
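
This section does not show how the audit trail is implemented. As a hedged sketch of one common approach, assuming append-only JSON Lines storage, each `predict()` result could be logged with the fields it returns. `append_audit_record`, the pseudonymous `patient_id`, and `model_version` are hypothetical names, not part of this codebase.

```python
import datetime
import json
from pathlib import Path

def append_audit_record(log_path: Path, patient_id: str, prediction: dict,
                        model_version: str) -> dict:
    """Append one prediction event to a JSON Lines audit log (hypothetical helper).

    Only a pseudonymous patient ID and the model output fields are stored;
    no image data or free-text identifiers are written.
    """
    record = {
        "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "patient_id": patient_id,
        "model_version": model_version,
        "predicted_class": prediction["predicted_class"],
        "confidence": round(float(prediction["confidence"]), 4),
    }
    log_path.parent.mkdir(parents=True, exist_ok=True)
    # Append mode keeps earlier records untouched
    with log_path.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
    return record
```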

---

# **Clinical Evidence Base**

This system incorporates guidelines from:
- American College of Rheumatology (ACR)
- Osteoarthritis Research Society International (OARSI)
- European League Against Rheumatism (EULAR)
- Recent systematic reviews and meta-analyses
- Conservative and holistic treatment approaches

---


In [1]:
# === Deployment Setup and Imports ===
import os, sys, json, datetime
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T

import streamlit as st
import plotly.express as px
import plotly.graph_objects as go

# Optional LLM support
try:
    import openai
    OPENAI_AVAILABLE = True
except ImportError:
    OPENAI_AVAILABLE = False

# Load OpenAI API key from Streamlit secrets
def get_openai_key() -> Optional[str]:
    """Get OpenAI API key from Streamlit secrets, falling back to the environment."""
    try:
        # Try Streamlit secrets first (for app deployment); accessing st.secrets
        # can raise when no secrets.toml exists, so any error means "not configured"
        if hasattr(st, 'secrets') and 'OPENAI_API_KEY' in st.secrets:
            return st.secrets['OPENAI_API_KEY']
    except Exception:
        pass

    # Fallback to environment variable (os is imported at the top of this cell)
    return os.getenv('OPENAI_API_KEY')

# Set up OpenAI if available; the module-level key is used by the default client
OPENAI_API_KEY = get_openai_key() if OPENAI_AVAILABLE else None
if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY

def get_repo_root(marker_files=(".git", "pyproject.toml", "requirements.txt")) -> Path:
    """Walk upward until we hit a repo marker; fallback to '..'."""
    cur = Path.cwd().resolve()
    for parent in [cur, *cur.parents]:
        if any((parent / m).exists() for m in marker_files):
            return parent
    return Path("..").resolve()

REPO_ROOT = get_repo_root()
REPO_NAME = REPO_ROOT.name
sys.path.append(str(REPO_ROOT))

print("Clinical Deployment System Setup Complete")
print(f"Repository      : {REPO_NAME}")
print(f"OpenAI Available: {OPENAI_AVAILABLE}")
print(f"PyTorch Version : {torch.__version__}")
print(f"CUDA Available  : {torch.cuda.is_available()}")

Clinical Deployment System Setup Complete
Repository      : osteoarthritis-severity
OpenAI Available: False
PyTorch Version : 2.5.1
CUDA Available  : True
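
`get_repo_root()` searches upward from the current working directory, which is why the setup cell resolves the same repository whether it runs from `notebooks/` or from the repository root. The variant below takes the start directory as a parameter so the search can be tested in isolation; `find_repo_root` is a hypothetical name, and its fallback (the parent of the start directory) matches `Path("..").resolve()` when the start is the working directory.

```python
from pathlib import Path

def find_repo_root(start: Path,
                   marker_files=(".git", "pyproject.toml", "requirements.txt")) -> Path:
    """Return the first directory at or above `start` that contains a marker file."""
    cur = start.resolve()
    for parent in [cur, *cur.parents]:
        if any((parent / m).exists() for m in marker_files):
            return parent
    # No marker anywhere above: fall back to the parent of the start directory
    return cur.parent
```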


In [2]:
# === Model Deployment Classes ===

import torchvision.models as tv_models

# ---- Helper: map cfg['model_name'] to a constructor ----
MODEL_BUILDERS = {
    "resnet50":        lambda cfg: _build_resnet50(cfg),
    "resnet":          lambda cfg: _build_resnet50(cfg),
    "densenet121":     lambda cfg: _build_densenet121(cfg),
    "densenet":        lambda cfg: _build_densenet121(cfg),
    "efficientnet_b0": lambda cfg: _build_efficientnet_b0(cfg),
    "efficientnet":    lambda cfg: _build_efficientnet_b0(cfg),
    "convnext_tiny":   lambda cfg: _build_convnext_tiny(cfg),
    "convnext":        lambda cfg: _build_convnext_tiny(cfg),
    "regnet_y_800mf":  lambda cfg: _build_regnet_y_800mf(cfg),
    "regnet":          lambda cfg: _build_regnet_y_800mf(cfg),
}

def _build_resnet50(cfg):
    m = tv_models.resnet50(weights=None)
    hidden_dim = cfg.get("hidden_dim", 384)
    num_classes = cfg.get("num_classes", 5)
    dropout_rate = cfg.get("dropout", 0.1)
    
    # Store original fc in_features before replacing
    in_features = m.fc.in_features
    
    # Build fc to match training exactly: [dropout, linear_hidden, relu, dropout, linear_output]
    m.fc = nn.Sequential(
        nn.Dropout(dropout_rate),                     # fc[0]
        nn.Linear(in_features, hidden_dim),           # fc[1]  
        nn.ReLU(),                                    # fc[2]
        nn.Dropout(dropout_rate),                     # fc[3]
        nn.Linear(hidden_dim, num_classes),           # fc[4]
    )
    return m

def _build_densenet121(cfg):
    m = tv_models.densenet121(weights=None)
    hidden_dim = cfg.get("hidden_dim", 512)
    num_classes = cfg.get("num_classes", 5)
    dropout_rate = cfg.get("dropout", 0.4)
    
    # Store original classifier in_features before replacing
    in_features = m.classifier.in_features
    
    # Build classifier to match training exactly: [dropout, linear_hidden, relu, dropout, linear_output]
    m.classifier = nn.Sequential(
        nn.Dropout(dropout_rate),                     # classifier[0]
        nn.Linear(in_features, hidden_dim),           # classifier[1]
        nn.ReLU(),                                    # classifier[2]
        nn.Dropout(dropout_rate),                     # classifier[3]
        nn.Linear(hidden_dim, num_classes),           # classifier[4]
    )
    return m

def _build_efficientnet_b0(cfg):
    m = tv_models.efficientnet_b0(weights=None)
    hidden_dim = cfg.get("hidden_dim", 384)
    num_classes = cfg.get("num_classes", 5)
    dropout_rate = cfg.get("dropout", 0.4)
    
    # Store original classifier in_features before replacing
    in_features = m.classifier[1].in_features
    
    # Build classifier to match training exactly: [dropout, linear_hidden, relu, dropout, linear_output]
    m.classifier = nn.Sequential(
        nn.Dropout(dropout_rate),                     # classifier[0]
        nn.Linear(in_features, hidden_dim),           # classifier[1] 
        nn.ReLU(),                                    # classifier[2]
        nn.Dropout(dropout_rate),                     # classifier[3]
        nn.Linear(hidden_dim, num_classes),           # classifier[4]
    )
    return m

def _build_convnext_tiny(cfg):
    m = tv_models.convnext_tiny(weights=None)
    hidden_dim = cfg.get("hidden_dim", 512)
    num_classes = cfg.get("num_classes", 5)
    dropout_rate = cfg.get("dropout", 0.2)
    
    # Store original classifier in_features before replacing last layer
    in_features = m.classifier[-1].in_features
    
    # Build classifier head to match training exactly: [dropout, linear_hidden, relu, dropout, linear_output]
    head = nn.Sequential(
        nn.Dropout(dropout_rate),                     # head[0]
        nn.Linear(in_features, hidden_dim),           # head[1]
        nn.ReLU(),                                    # head[2]
        nn.Dropout(dropout_rate),                     # head[3]
        nn.Linear(hidden_dim, num_classes),           # head[4]
    )
    
    # Replace only the last layer (like training code does)
    m.classifier[-1] = head
    return m

def _build_regnet_y_800mf(cfg):
    m = tv_models.regnet_y_800mf(weights=None)
    hidden_dim = cfg.get("hidden_dim", 768)
    num_classes = cfg.get("num_classes", 5)
    dropout_rate = cfg.get("dropout", 0.2)
    
    # Store original fc in_features before replacing
    in_features = m.fc.in_features
    
    # Build fc to match training exactly: [dropout, linear_hidden, relu, dropout, linear_output]
    m.fc = nn.Sequential(
        nn.Dropout(dropout_rate),                     # fc[0]
        nn.Linear(in_features, hidden_dim),           # fc[1]
        nn.ReLU(),                                    # fc[2]
        nn.Dropout(dropout_rate),                     # fc[3]
        nn.Linear(hidden_dim, num_classes),           # fc[4]
    )
    return m


class EnsembleModel(nn.Module):
    """Weighted-logit ensemble used in deployment (same as training)."""
    def __init__(self, models: List[nn.Module], weights: np.ndarray, device: str):
        super().__init__()
        self.device  = device
        self.models  = nn.ModuleList([m.to(device) for m in models])
        self.register_buffer("weights", torch.tensor(weights, dtype=torch.float32, device=device))

    def forward(self, x, return_logits: bool = True):
        x = x.to(self.device)
        logits_list = []
        for m in self.models:
            m.eval()
            with torch.no_grad():
                logits_list.append(m(x))
        stacked = torch.stack(logits_list, dim=0)                # [M, B, C]
        w       = self.weights.view(-1, 1, 1)                    # [M, 1, 1]
        ens_logits = (stacked * w).sum(dim=0)                    # [B, C]
        return ens_logits if return_logits else F.softmax(ens_logits, dim=1)


class IndividualModelWrapper(nn.Module):
    """Wrapper for individual models to match ensemble interface."""
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
    
    def forward(self, x, return_logits: bool = True):
        logits = self.model(x)  # Individual models just return logits
        return logits if return_logits else F.softmax(logits, dim=1)


class TemperatureScaling(nn.Module):
    """Wrapper to apply learned temperature on logits (inference only)."""
    def __init__(self, model: nn.Module, temperature: float, device: str):
        super().__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.tensor([temperature], device=device), requires_grad=False)

    def forward(self, x, return_logits: bool = True):
        logits = self.model(x, return_logits=True)
        scaled = logits / self.temperature
        return scaled if return_logits else F.softmax(scaled, dim=1)


class OsteoarthritisClassificationModel:
    """Production-ready osteoarthritis classification model (calibrated ensemble)."""
    def __init__(self, ckpt_path: Optional[str] = None):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.class_names = ['Normal', 'Doubtful', 'Mild', 'Moderate', 'Severe']  # match training list
        self.class_descriptions = {
            'Normal'  : 'No signs of osteoarthritis',
            'Doubtful': 'Possible early osteoarthritis changes',
            'Mild'    : 'Mild osteoarthritis with minor joint changes',
            'Moderate': 'Moderate osteoarthritis with clear joint degeneration',
            'Severe'  : 'Severe osteoarthritis with significant joint damage',
        }

        if ckpt_path and Path(ckpt_path).exists():
            self.model = self._load_calibrated_ensemble(ckpt_path)
        else:
            st.warning("No trained ensemble found. Using demo ResNet-50 model.")
            self.model = self._create_demo_model(num_classes=len(self.class_names))

        self.model.eval()

        # Same normalization as training with grayscale conversion
        self.transform = T.Compose([
            T.Resize((224, 224)),
            T.Grayscale(num_output_channels=3),  # Collapse to grayscale, then replicate to 3 channels
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
                        std =[0.229, 0.224, 0.225])
        ])

    # ---------------- Internal helpers ----------------
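    def _create_demo_model(self, num_classes: int):
        # ImageNet backbone with an untrained 5-class head: predictions are not clinically meaningful
        m = tv_models.resnet50(weights="IMAGENET1K_V2")
        m.fc = nn.Linear(m.fc.in_features, num_classes)
        # Wrap so forward() accepts the `return_logits` kwarg that predict() passes
        return IndividualModelWrapper(m).to(self.device)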

    def _load_calibrated_ensemble(self, ckpt_path: str) -> nn.Module:
        ckpt = torch.load(ckpt_path, map_location=self.device, weights_only=False)
        
        # Check if this is an individual model or ensemble
        model_type = ckpt.get("model_type", "calibrated_ensemble")
        deployment_strategy = ckpt.get("deployment_strategy", "best_ensemble")
        ensemble_strategy = ckpt.get("ensemble_strategy", "unknown")
        
        # Rebuild base nets from saved configs
        model_cfgs = ckpt["model_configs"]
        base_models = []
        for i, cfg in enumerate(model_cfgs):
            name = cfg["model_name"]
            builder = MODEL_BUILDERS.get(name.lower())
            if builder is None:
                raise ValueError(f"Unknown model_name '{name}' in checkpoint.")
            net = builder(cfg)
            # state_dicts were saved in the same order as model_configs
            net.load_state_dict(ckpt["individual_models"][i])
            base_models.append(net)

        weights = np.array(ckpt["ensemble_weights"])
        temperature = float(ckpt["temperature"])
        
        # Handle both ensemble and individual model cases
        if "individual" in model_type or ensemble_strategy == "Top-1 (Best Individual)":
            # For individual models, find the model with weight = 1.0
            selected_idx = np.argmax(weights)
            selected_model = base_models[selected_idx]
            print(f"Loading individual model: {model_cfgs[selected_idx]['display_name']}")
            # Wrap individual model to match ensemble interface
            wrapped_model = IndividualModelWrapper(selected_model)
            calibrated = TemperatureScaling(wrapped_model, temperature, device=self.device).to(self.device)
        else:
            # Traditional ensemble
            print(f"Loading ensemble with strategy: {ensemble_strategy}")
            ensemble = EnsembleModel(base_models, weights, device=self.device)
            calibrated = TemperatureScaling(ensemble, temperature, device=self.device).to(self.device)
        
        return calibrated

    # ---------------- Public API ----------------
    def predict(self, image: Image.Image) -> Dict:
        """Single-image prediction."""
        x = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits = self.model(x, return_logits=True)
            probs  = F.softmax(logits, dim=1).cpu().numpy()[0]
            idx    = int(np.argmax(probs))

        pred_name = self.class_names[idx]
        return {
            "predicted_class" : pred_name,
            "predicted_index" : idx,
            "confidence"      : float(probs[idx]),
            "all_probabilities": {
                cls: float(p) for cls, p in zip(self.class_names, probs)
            },
            "description"     : self.class_descriptions[pred_name],
        }

    def predict_batch(self, images: List[Image.Image]) -> List[Dict]:
        return [self.predict(img) for img in images]


print("Osteoarthritis Classification Model class defined")


Osteoarthritis Classification Model class defined
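
`_load_calibrated_ensemble()` assumes a specific checkpoint layout: parallel lists `model_configs` and `individual_models` saved in the same order, one `ensemble_weights` entry per model, and a scalar `temperature`. A hedged pre-flight check inferred from the keys the loader reads might look like the sketch below. The function name `checkpoint_problems` is hypothetical, and the checks are not exhaustive (for example, it does not verify that state-dict tensors match the classifier heads). It is meant to run on the dict returned by `torch.load(..., weights_only=False)`.

```python
from typing import Any, Dict, List

REQUIRED_KEYS = ("model_configs", "individual_models", "ensemble_weights", "temperature")
# Keys accepted by MODEL_BUILDERS in the cell above
KNOWN_MODEL_NAMES = {
    "resnet50", "resnet", "densenet121", "densenet", "efficientnet_b0",
    "efficientnet", "convnext_tiny", "convnext", "regnet_y_800mf", "regnet",
}

def checkpoint_problems(ckpt: Dict[str, Any]) -> List[str]:
    """Return structural problems in a checkpoint dict; an empty list means none found."""
    problems = [f"missing key '{k}'" for k in REQUIRED_KEYS if k not in ckpt]
    if problems:
        return problems
    n = len(ckpt["model_configs"])
    if len(ckpt["individual_models"]) != n:
        problems.append(f"{len(ckpt['individual_models'])} state_dicts for {n} configs")
    if len(ckpt["ensemble_weights"]) != n:
        problems.append(f"{len(ckpt['ensemble_weights'])} weights for {n} configs")
    for cfg in ckpt["model_configs"]:
        if str(cfg.get("model_name", "")).lower() not in KNOWN_MODEL_NAMES:
            problems.append(f"unknown model_name '{cfg.get('model_name')}'")
    # TemperatureScaling divides logits by T, so T must be positive
    if float(ckpt["temperature"]) <= 0:
        problems.append("temperature must be positive")
    return problems
```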


In [3]:
# === LLM-Powered Treatment Planning System ===

class ClinicalTreatmentPlanner:
    """AI-powered treatment planning system for osteoarthritis."""
    def __init__(self, api_key: Optional[str] = None):
        self.llm_available = OPENAI_AVAILABLE and bool(api_key)  # empty key counts as "not configured"
        if self.llm_available:
            openai.api_key = api_key

    def generate_treatment_plan(self, classification_result: Dict, patient_data: Dict) -> Dict:
        if self.llm_available:
            try:
                return self._generate_llm_treatment_plan(classification_result, patient_data)
            except Exception as e:
                st.error(f"LLM Error: {e}")
        return self._generate_rule_based_treatment_plan(classification_result, patient_data)

    # ---------------- LLM path ----------------
    def _generate_llm_treatment_plan(self, classification_result: Dict, patient_data: Dict) -> Dict:
        prompt = self._construct_clinical_prompt(classification_result, patient_data)
        # openai>=1.0 API (the legacy openai.ChatCompletion interface was removed in 1.0)
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": self._get_system_prompt()},
                {"role": "user", "content": prompt},
            ],
            max_tokens=1500,
            temperature=0.3,
        )
        text = response.choices[0].message.content
        return self._parse_treatment_response(text)

    # ---------------- Rule-based path ----------------
    def _generate_rule_based_treatment_plan(self, classification_result: Dict, patient_data: Dict) -> Dict:
        severity = classification_result["predicted_index"]
        base_plan = self._severity_templates().get(severity, self._severity_templates()[2])
        personalized = self._personalize_treatment_plan(base_plan, patient_data)
        return {
            "severity_level"     : classification_result["predicted_class"],
            "confidence"         : classification_result["confidence"],
            "primary_approach"   : personalized["primary"],
            "medications"        : personalized["medications"],
            "non_pharmacological": personalized["non_pharmacological"],
            "lifestyle"          : personalized["lifestyle"],
            "surgical_options"   : personalized.get("surgical_options", []),
            "follow_up"          : self._get_follow_up_recommendations(severity),
            "red_flags"          : self._get_red_flags(),
            "references"         : self._get_clinical_references(),
        }

    # ---------------- Templates & utils ----------------
    def _severity_templates(self) -> Dict[int, Dict]:
        return {
            0: {
                "primary": "Preventive care and lifestyle modification",
                "medications": ["No medications needed"],
                "non_pharmacological": [
                    "Regular low-impact exercise", "Weight management", "Joint protection", "Annual monitoring"
                ],
                "lifestyle": [
                    "Healthy weight", "Regular activity", "Anti-inflammatory diet", "Adequate sleep"
                ],
            },
            1: {
                "primary": "Early intervention and lifestyle modification",
                "medications": ["Topical NSAIDs if symptomatic"],
                "non_pharmacological": [
                    "Physical therapy evaluation", "Low-impact aerobic exercise",
                    "Strength training", "Patient education on joint protection"
                ],
                "lifestyle": [
                    "Weight loss if BMI >25", "Anti-inflammatory diet", "Stress management", "Ergonomic setup"
                ],
            },
            2: {
                "primary": "Conservative management with symptom control",
                "medications": [
                    "Topical NSAIDs (first-line)", "Short-term oral NSAIDs", "Acetaminophen for pain"
                ],
                "non_pharmacological": [
                    "Structured exercise program", "Physical therapy", "Hot/cold therapy",
                    "Support devices", "Acupuncture"
                ],
                "lifestyle": [
                    "Mediterranean diet", "Weight management program", "Yoga or tai chi", "Sleep hygiene"
                ],
            },
            3: {
                "primary": "Multimodal pain management & function preservation",
                "medications": [
                    "Topical NSAIDs", "Oral NSAIDs (with gastroprotection)",
                    "Corticosteroid injections", "Consider hyaluronic acid"
                ],
                "non_pharmacological": [
                    "Comprehensive physiotherapy", "Occupational therapy",
                    "CBT for pain", "Assistive devices", "TENS unit"
                ],
                "lifestyle": [
                    "Structured weight loss", "Aquatic therapy", "Joint-friendly activities", "Stress reduction"
                ],
            },
            4: {
                "primary": "Advanced pain management & surgical evaluation",
                "medications": [
                    "Combination analgesics", "Intra-articular treatments",
                    "Consider duloxetine", "Topical capsaicin"
                ],
                "non_pharmacological": [
                    "Multidisciplinary pain management", "Pre-surgical physiotherapy",
                    "Psychological support", "Mobility aids", "Orthopedic referral"
                ],
                "surgical_options": [
                    "Total knee replacement evaluation", "Partial replacement", "Osteotomy (select cases)"
                ],
                "lifestyle": [
                    "Joint protection strategies", "Smoking cessation if applicable"
                ]
            },
        }

    def _personalize_treatment_plan(self, base_plan: Dict, patient_data: Dict) -> Dict:
        plan = {k: v[:] if isinstance(v, list) else v for k, v in base_plan.items()}  # shallow copy lists
        age          = patient_data.get("age", 50)
        gender       = patient_data.get("gender", "Other")
        comorbidities = [c.lower() for c in patient_data.get("comorbidities", [])]

        if age > 65:
            plan["medications"] = [m for m in plan["medications"] if "NSAIDs" not in m] + [
                "Avoid oral NSAIDs (↑ GI/CV risk)", "Acetaminophen preferred", "Low-dose topical NSAIDs"
            ]
            plan["non_pharmacological"].append("Fall prevention assessment")

        if gender.lower() == "female":
            plan["lifestyle"].append("Consider calcium & vitamin D")
            if age > 50:
                plan["lifestyle"].append("Bone health evaluation (post-menopausal)")

        if "diabetes" in comorbidities:
            plan["lifestyle"].append("Diabetic-friendly exercise program")
        if "cardiovascular" in comorbidities:
            plan["medications"] = [m for m in plan["medications"] if "NSAIDs" not in m]
            plan["medications"].append("Avoid NSAIDs due to CV risk")

        return plan

    def _get_follow_up_recommendations(self, severity: int) -> List[str]:
        follow_up = {
            0: ["Annual screening", "Monitor for symptoms"],
            1: ["6-month follow-up", "Monitor progression"],
            2: ["3-month follow-up", "Assess response"],
            3: ["6-week follow-up", "Pain/Function reassessment"],
            4: ["2-week follow-up", "Surgical consult", "Pain optimization"],
        }
        return follow_up.get(severity, follow_up[2])

    def _get_red_flags(self) -> List[str]:
        return [
            "Severe, uncontrolled pain",
            "Signs of infection (fever, redness, warmth)",
            "Significant functional decline",
            "Inability to bear weight",
            "Neurological symptoms",
            "Suspected fracture",
        ]

    def _get_clinical_references(self) -> List[str]:
        return [
            "Kolasinski et al. 2020: 2019 ACR/Arthritis Foundation Guideline (Hand, Hip, Knee OA)",
            "OARSI Guidelines (Non-surgical Management)",
            "NICE Guideline: Osteoarthritis in Over 16s (2022)",
            "McAlindon et al. 2014 OARSI Guidelines",
        ]

    def _construct_clinical_prompt(self, classification_result: Dict, patient_data: Dict) -> str:
        return f"""
Clinical Case Assessment:
- X-ray Classification: {classification_result['predicted_class']} ({classification_result['description']})
- Confidence: {classification_result['confidence']:.2%}

Patient Information:
- Age: {patient_data.get('age', 'Unknown')}
- Gender: {patient_data.get('gender', 'Unknown')}
- Symptoms: {', '.join(patient_data.get('symptoms') or ['None reported'])}
- Treatment Expectations: {patient_data.get('expectations', 'Unknown')}
- Comorbidities: {', '.join(patient_data.get('comorbidities') or ['None reported'])}

Please provide a comprehensive, evidence-based treatment plan following current clinical guidelines.
"""

    def _get_system_prompt(self) -> str:
        return (
            "You are an expert rheumatologist and orthopedic specialist. "
            "Provide evidence-based treatment recommendations for osteoarthritis "
            "following current ACR, OARSI, and NICE guidelines. Focus on holistic care."
        )

    def _parse_treatment_response(self, response: str) -> Dict:
        return {
            "llm_response": response,
            "generated_by": "LLM",
            "timestamp": datetime.datetime.now().isoformat(),
        }

print("Clinical Treatment Planner class defined")

Clinical Treatment Planner class defined
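
The rule-based path returns a plain dictionary, which makes it straightforward to export for the medical record. A minimal sketch of rendering that dict as Markdown, assuming the keys produced by `_generate_rule_based_treatment_plan()`, follows. The helper `plan_to_markdown` is hypothetical, and the LLM path (which returns free text under `llm_response`) would need separate handling.

```python
from typing import Dict, List

# (plan key, section title) in report order
SECTION_TITLES = [
    ("primary_approach", "Primary approach"),
    ("medications", "Medications"),
    ("non_pharmacological", "Non-pharmacological"),
    ("lifestyle", "Lifestyle"),
    ("surgical_options", "Surgical options"),
    ("follow_up", "Follow-up"),
    ("red_flags", "Red flags"),
]

def plan_to_markdown(plan: Dict) -> str:
    """Render a rule-based treatment plan dict as a Markdown report (hypothetical helper)."""
    lines: List[str] = [
        f"# Treatment plan: {plan['severity_level']} "
        f"(model confidence {plan['confidence']:.0%})"
    ]
    for key, title in SECTION_TITLES:
        value = plan.get(key)
        if not value:
            continue  # e.g. surgical_options is an empty list below Severe
        lines.append(f"\n## {title}")
        if isinstance(value, list):
            lines.extend(f"- {item}" for item in value)
        else:
            lines.append(str(value))
    return "\n".join(lines)
```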


In [None]:
# === Streamlit Clinical UI Application ===

def main():
    """Main Streamlit application."""
    st.set_page_config(
        page_title="Osteoarthritis Clinical Decision Support",
        page_icon="🏥",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # CSS
    st.markdown(
        """
        <style>
        .main-header {
            background: linear-gradient(90deg, #1e3c72 0%, #2a5298 100%);
            padding: 1rem; border-radius: 10px; color: white; text-align: center; margin-bottom: 2rem;
        }
        .prediction-box {
            border: 2px solid #2a5298; border-radius: 10px; padding: 1rem; margin: 1rem 0; background-color: #f8f9fa;
        }
        .treatment-section {
            border-left: 4px solid #28a745; padding-left: 1rem; margin: 1rem 0;
        }
        .warning-box {
            background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 5px; padding: 1rem; margin: 1rem 0;
        }
        </style>
        """,
        unsafe_allow_html=True,
    )

    # Header
    st.markdown(
        """
        <div class="main-header">
            <h1>Osteoarthritis Clinical Decision Support System</h1>
            <p>AI-Powered Diagnosis and Evidence-Based Treatment Planning</p>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # Init session state
    if "model" not in st.session_state:
        with st.spinner("Loading AI model..."):
            # Load best model from deployment (saved by notebook 04)
            default_ckpt = REPO_ROOT / "models" / "deployment" / "best_model_for_deployment.pth"
            try:
                st.session_state.model = OsteoarthritisClassificationModel(str(default_ckpt))
                if default_ckpt.exists():
                    # Load and display model info
                    ckpt = torch.load(default_ckpt, map_location="cpu", weights_only=False)
                    model_type = ckpt.get("model_type", "unknown")
                    deployment_strategy = ckpt.get("deployment_strategy", "unknown")
                    ensemble_strategy = ckpt.get("ensemble_strategy", "unknown")
                    
                    if "individual" in model_type:
                        st.success(f"Best individual model loaded: {ensemble_strategy}")
                        st.info(f"Model type: {model_type} | Strategy: {deployment_strategy}")
                    else:
                        st.success(f"Best ensemble model loaded: {ensemble_strategy}")
                        st.info(f"Model type: {model_type} | Strategy: {deployment_strategy}")
                else:
                    st.warning("No trained model found - using demo model instead")
                    st.info("Run notebook 04 (Multi_Class_Full_Training_Ensemble) to train your model")
            except Exception as e:
                st.error(f"Failed to load trained model: {e}")
                st.warning("Using default demo model instead")
                st.session_state.model = OsteoarthritisClassificationModel()

    if "treatment_planner" not in st.session_state:
        # Initialize with API key from secrets
        api_key = get_openai_key()
        st.session_state.treatment_planner = ClinicalTreatmentPlanner(api_key)
        if api_key:
            st.success("AI-powered treatment recommendations enabled!", icon="✅")
        else:
            st.info("LLM integration not configured - using rule-based recommendations", icon="ℹ️")

    # Sidebar
    st.sidebar.title("Navigation")
    page = st.sidebar.selectbox(
        "Choose Function",
        ["Single Patient Analysis", "Batch Processing", "Analytics Dashboard", "Settings"],
    )

    if page == "Single Patient Analysis":
        single_patient_interface()
    elif page == "Batch Processing":
        batch_processing_interface()
    elif page == "Analytics Dashboard":
        analytics_dashboard()
    elif page == "Settings":
        settings_interface()


def single_patient_interface():
    st.header("Single Patient Analysis")
    
    # Demo Patients Quick Load Section
    st.subheader("🏥 Demo Patients")
    demo_patients_dir = REPO_ROOT / "data" / "consensus" / "demo_patients"
    
    if demo_patients_dir.exists():
        # Load demo patient metadata
        metadata_file = demo_patients_dir / "demo_patients_metadata.json"
        demo_metadata = {}
        if metadata_file.exists():
            try:
                with open(metadata_file, 'r') as f:
                    demo_metadata = json.load(f)
            except (OSError, json.JSONDecodeError):
                pass  # unreadable or malformed metadata: no demo patients are listed
        
        # Create demo patient selection
        demo_options = ["Select a demo patient..."]
        demo_files = {}
        
        for severity_dir in sorted(demo_patients_dir.iterdir()):
            if severity_dir.is_dir():
                for img_file in severity_dir.glob("*.png"):
                    if img_file.name in demo_metadata:
                        patient_info = demo_metadata[img_file.name]
                        # Don't reveal the actual severity class in the selection
                        display_name = f"{patient_info['name']} ({patient_info['age']}{patient_info['gender'][0]}) - {patient_info['occupation']}"
                        demo_options.append(display_name)
                        # Store actual severity from folder name for comparison
                        actual_severity = severity_dir.name[1:]  # Remove the leading number
                        demo_files[display_name] = (img_file, patient_info, actual_severity)
        
        selected_demo = st.selectbox("Quick Load Demo Patient", demo_options)
        
        if selected_demo != "Select a demo patient..." and selected_demo in demo_files:
            demo_file, patient_info, actual_severity = demo_files[selected_demo]
            
            # Load and display demo patient (without revealing actual classification)
            st.info(f"Loaded: **{patient_info['name']}** - {patient_info['occupation']}")
            
            # Auto-populate patient information from metadata
            col1, col2, col3 = st.columns(3)
            with col1:
                patient_age = st.number_input("Age", min_value=18, max_value=120, value=patient_info['age'])
                patient_gender = st.selectbox("Gender", ["Male", "Female", "Other"], 
                                             index=["Male", "Female", "Other"].index(patient_info['gender']))
            with col2:
                symptom_options = [
                    "Joint pain", "Stiffness", "Swelling", "Reduced mobility",
                    "Grinding sensation", "Joint instability", "Morning stiffness"
                ]
                # Map free-text symptoms from the demo metadata onto the dropdown options.
                # One symptom may match several options; "morning" plus "stiff" maps to
                # "Morning stiffness", any other "stiff" mention maps to "Stiffness".
                patient_symptoms = patient_info.get('symptoms', [])
                symptom_default = []
                for symptom in patient_symptoms:
                    symptom_lower = symptom.lower()
                    matches = []
                    if any(word in symptom_lower for word in ['pain', 'ache', 'hurt']):
                        matches.append("Joint pain")
                    if 'stiff' in symptom_lower:
                        matches.append("Morning stiffness" if 'morning' in symptom_lower else "Stiffness")
                    if any(word in symptom_lower for word in ['swell', 'inflam']):
                        matches.append("Swelling")
                    if any(word in symptom_lower for word in ['mobility', 'movement', 'difficulty', 'limitation']):
                        matches.append("Reduced mobility")
                    for option in matches:
                        if option not in symptom_default:
                            symptom_default.append(option)
                
                symptoms = st.multiselect("Current Symptoms", symptom_options, default=symptom_default)
            with col3:
                expectation_options = [
                    "Pain relief", "Improved mobility", "Prevent progression", "Return to activities", "Surgery avoidance"
                ]
                # Map the free-text expectation onto a dropdown option: the first
                # matching rule wins, and "Pain relief" is the fallback.
                patient_expectation = patient_info.get('treatment_expectations', '').lower()
                expectation_rules = [
                    ("Prevent progression", ['prevent', 'progression', 'early']),
                    ("Improved mobility", ['mobility', 'active', 'movement', 'function']),
                    ("Pain relief", ['pain', 'relief', 'manage']),
                    ("Return to activities", ['work', 'activity', 'return', 'continue']),
                    ("Surgery avoidance", ['avoid', 'surgery', 'non-surgical']),
                ]
                expectation_index = 0  # Default: "Pain relief"
                for option, keywords in expectation_rules:
                    if any(word in patient_expectation for word in keywords):
                        expectation_index = expectation_options.index(option)
                        break
                
                expectations = st.selectbox("Treatment Expectations", expectation_options, index=expectation_index)
                
                comorbidity_options = [
                    "Diabetes", "Cardiovascular disease", "Hypertension", "Kidney disease"
                ]
                # Smart mapping for comorbidities
                patient_comorbidities = patient_info.get('comorbidities', [])
                comorbidity_default = []
                for condition in patient_comorbidities:
                    condition_lower = condition.lower()
                    if any(word in condition_lower for word in ['diabetes', 'diabetic']):
                        if "Diabetes" not in comorbidity_default:
                            comorbidity_default.append("Diabetes")
                    if any(word in condition_lower for word in ['hypertension', 'high blood pressure']):
                        if "Hypertension" not in comorbidity_default:
                            comorbidity_default.append("Hypertension")
                    if any(word in condition_lower for word in ['cardiovascular', 'heart', 'cardiac']):
                        if "Cardiovascular disease" not in comorbidity_default:
                            comorbidity_default.append("Cardiovascular disease")
                    if any(word in condition_lower for word in ['kidney', 'renal']):
                        if "Kidney disease" not in comorbidity_default:
                            comorbidity_default.append("Kidney disease")
                
                comorbidities = st.multiselect("Comorbidities", comorbidity_options, default=comorbidity_default)
            
            # Auto-load the demo image
            demo_image = Image.open(demo_file)
            
            c1, c2 = st.columns([1, 2])
            with c1:
                st.image(demo_image, caption=f"Demo Patient: {patient_info['name']}", use_container_width=True)
                
                # Display patient context
                with st.expander("Patient Context"):
                    st.write(f"**Occupation:** {patient_info['occupation']}")
                    st.write(f"**BMI:** {patient_info.get('bmi', 'N/A')}")
                    st.write(f"**Activity Level:** {patient_info.get('activity_level', 'N/A')}")
                    if patient_info.get('medical_history'):
                        st.write(f"**Medical History:** {', '.join(patient_info['medical_history'])}")
                    if patient_info.get('medications'):
                        st.write(f"**Current Medications:** {', '.join(patient_info['medications'])}")

            with c2:
                with st.spinner("Analyzing demo patient X-ray (blind prediction)..."):
                    # Make blind prediction without knowing actual classification
                    prediction = st.session_state.model.predict(demo_image)

                st.markdown('<div class="prediction-box"><h3>🔍 AI Analysis Results</h3></div>', unsafe_allow_html=True)
                
                # Show AI prediction
                ai_severity = prediction["predicted_class"]
                confidence = prediction["confidence"]
                st.metric("AI Predicted Severity", ai_severity, delta=f"Confidence: {confidence:.1%}")
                st.write(f"**AI Assessment:** {prediction['description']}")
                
                # Show comparison with actual classification
                st.markdown("---")
                st.subheader("Prediction Validation")
                
                col_pred, col_actual, col_match = st.columns(3)
                with col_pred:
                    st.metric("AI Prediction", ai_severity)
                with col_actual:
                    st.metric("Actual Classification", actual_severity)
                with col_match:
                    is_correct = ai_severity == actual_severity
                    match_status = "✅ Correct" if is_correct else "❌ Incorrect"
                    if is_correct:
                        st.success(f"**{match_status}**")
                    else:
                        st.error(f"**{match_status}**")

                st.subheader("Probability Distribution")
                prob_data = prediction["all_probabilities"]
                fig = px.bar(x=list(prob_data.keys()), y=list(prob_data.values()),
                            title="AI Classification Probabilities",
                            labels={"x": "Severity Level", "y": "Probability"})
                
                # Highlight the actual classification in the chart
                colors = ['red' if severity == actual_severity else 'lightblue' for severity in prob_data.keys()]
                fig.update_traces(marker_color=colors)
                fig.update_layout(showlegend=False)
                st.plotly_chart(fig, use_container_width=True)
                
                # Add annotation about the highlighted bar
                st.caption(f"🔴 Red bar shows actual classification ({actual_severity})")

            if st.button("Generate Treatment Plan Based on AI Prediction", type="primary"):
                patient_data = {
                    "age": patient_age, "gender": patient_gender, "symptoms": symptoms,
                    "expectations": expectations, "comorbidities": [c.lower() for c in comorbidities],
                    "occupation": patient_info['occupation'],
                    "activity_level": patient_info.get('activity_level', 'Moderate'),
                    "bmi": patient_info.get('bmi', 25.0)
                }
                with st.spinner("Generating treatment plan based on AI prediction..."):
                    plan = st.session_state.treatment_planner.generate_treatment_plan(prediction, patient_data)
                
                # Enhanced treatment plan display with validation context
                display_treatment_plan_with_validation(plan, prediction, actual_severity, patient_info)
            
            st.markdown("---")
    
    # Regular Patient Upload Section
    st.subheader("Upload Patient X-ray")
    
    # Patient info
    st.subheader("Patient Information")
    col1, col2, col3 = st.columns(3)

    with col1:
        patient_age = st.number_input("Age", min_value=18, max_value=120, value=50, key="regular_age")
        patient_gender = st.selectbox("Gender", ["Male", "Female", "Other"], key="regular_gender")
    with col2:
        symptom_options = [
            "Joint pain", "Stiffness", "Swelling", "Reduced mobility",
            "Grinding sensation", "Joint instability", "Morning stiffness",
        ]
        symptoms = st.multiselect("Current Symptoms", symptom_options, key="regular_symptoms")
    with col3:
        expectation_options = [
            "Pain relief", "Improved mobility", "Prevent progression", "Return to activities", "Surgery avoidance"
        ]
        expectations = st.selectbox("Treatment Expectations", expectation_options, key="regular_expectations")
        comorbidities = st.multiselect(
            "Comorbidities",
            ["Diabetes", "Cardiovascular disease", "Hypertension", "Kidney disease", "Gastrointestinal disease", "Depression"],
            key="regular_comorbidities"
        )

    # Image upload
    uploaded_file = st.file_uploader("Upload knee X-ray image", type=["jpg", "jpeg", "png"])

    if uploaded_file:
        # Confirm the upload; the uploader above stays available for choosing another image
        st.success(f"✅ File uploaded: {uploaded_file.name}")
        with st.expander("📁 Upload Different Image", expanded=False):
            st.info("To upload a different image, use the file uploader above and select a new file.")
        image = Image.open(uploaded_file)
        c1, c2 = st.columns([1, 2])

        with c1:
            st.image(image, caption="Uploaded X-ray", use_container_width=True)

        with c2:
            with st.spinner("Analyzing X-ray..."):
                prediction = st.session_state.model.predict(image)

            st.markdown('<div class="prediction-box"><h3>AI Analysis Results</h3></div>', unsafe_allow_html=True)

            severity = prediction["predicted_class"]
            confidence = prediction["confidence"]
            st.metric("Osteoarthritis Severity", severity, delta=f"Confidence: {confidence:.1%}")

            st.write(f"**Description:** {prediction['description']}")

            st.subheader("Probability Distribution")
            prob_data = prediction["all_probabilities"]
            fig = px.bar(
                x=list(prob_data.keys()),
                y=list(prob_data.values()),
                title="Classification Probabilities",
                labels={"x": "Severity Level", "y": "Probability"},
            )
            fig.update_layout(showlegend=False)
            st.plotly_chart(fig, use_container_width=True)

        if st.button("Generate Treatment Plan", type="primary"):
            patient_data = {
                "age": patient_age,
                "gender": patient_gender,
                "symptoms": symptoms,
                "expectations": expectations,
                "comorbidities": [c.lower() for c in comorbidities],
            }
            with st.spinner("Generating personalized treatment plan..."):
                plan = st.session_state.treatment_planner.generate_treatment_plan(prediction, patient_data)
            display_treatment_plan(plan)
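

# Sketch (hypothetical helper, not used by the app yet): the comorbidity mapping in
# the demo-patient branch above repeats one pattern, matching keywords in free-text
# metadata against dropdown options. Assuming the metadata holds plain strings, a
# table-driven helper like this could replace that loop (the symptom loop would
# still need its extra morning-stiffness rule). The keyword lists mirror the ones
# used above; COMORBIDITY_RULES is an illustrative name.
def map_free_text_to_options(texts, rules):
    """Return options whose keywords occur in any text, in rule order, without duplicates."""
    lowered = [t.lower() for t in texts]
    selected = []
    for option, keywords in rules:
        if any(word in text for text in lowered for word in keywords):
            selected.append(option)
    return selected


COMORBIDITY_RULES = [
    ("Diabetes", ["diabetes", "diabetic"]),
    ("Cardiovascular disease", ["cardiovascular", "heart", "cardiac"]),
    ("Hypertension", ["hypertension", "high blood pressure"]),
    ("Kidney disease", ["kidney", "renal"]),
]

# Usage inside the demo branch (sketch):
#   comorbidity_default = map_free_text_to_options(
#       patient_info.get('comorbidities', []), COMORBIDITY_RULES)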


def display_treatment_plan(treatment_plan: Dict):
    st.markdown('<div class="treatment-section"><h2>Personalized Treatment Plan</h2></div>', unsafe_allow_html=True)

    # Check if this is an LLM-generated response or rule-based
    if "llm_response" in treatment_plan:
        # LLM-generated treatment plan
        st.subheader("🤖 AI-Generated Treatment Plan")
        st.write(treatment_plan["llm_response"])
        
        # Show additional metadata
        st.caption(f"Generated by: {treatment_plan.get('generated_by', 'AI')} at {treatment_plan.get('timestamp', 'Unknown time')}")
        
        # Add note about LLM usage
        st.info("💡 This treatment plan was generated by a large language model. Review it carefully against current clinical guidelines before acting on it.")
        
    else:
        # Rule-based treatment plan
        st.subheader("Primary Treatment Approach")
        st.write(treatment_plan["primary_approach"])

        st.subheader("Pharmacological Interventions")
        for med in treatment_plan["medications"]:
            st.write(f"• {med}")

        st.subheader("Non-Pharmacological Interventions")
        for item in treatment_plan["non_pharmacological"]:
            st.write(f"• {item}")

        st.subheader("Lifestyle Modifications")
        for item in treatment_plan["lifestyle"]:
            st.write(f"• {item}")

        if treatment_plan.get("surgical_options"):
            st.subheader("Surgical Considerations")
            for opt in treatment_plan["surgical_options"]:
                st.write(f"• {opt}")

        st.subheader("Follow-up Recommendations")
        for r in treatment_plan["follow_up"]:
            st.write(f"• {r}")

        st.markdown('<div class="warning-box"><h4>⚠️ Clinical Red Flags - Seek Immediate Attention</h4></div>', unsafe_allow_html=True)
        for flag in treatment_plan["red_flags"]:
            st.write(f"🚨 {flag}")

        # Clinical references (only for rule-based plans)
        if "references" in treatment_plan:
            with st.expander("Clinical References"):
                for ref in treatment_plan["references"]:
                    st.write(f"• {ref}")

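    # Placeholder actions: nothing is copied, exported, or saved yet. These buttons
    # are also rendered inside the caller's "Generate Treatment Plan" branch, so a
    # click reruns the script without that branch and the plan disappears.
    st.subheader("Export Options")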
    c1, c2, c3 = st.columns(3)
    with c1:
        if st.button("Copy to Clipboard"):
            st.success("Treatment plan copied!")
    with c2:
        if st.button("Generate PDF Report"):
            st.success("PDF report generated!")
    with c3:
        if st.button("Save to Patient Record"):
            st.success("Saved to patient record!")
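

# Sketch (hypothetical helper, not used by the app yet): Streamlit reruns the whole
# script on every widget interaction, and st.button returns True only on the run
# triggered by its own click. Because display_treatment_plan() is called inside the
# "Generate Treatment Plan" button branch, clicking an export button starts a run
# in which that branch is skipped. One common pattern is to keep the generated plan
# in st.session_state (which behaves like a dict) and render from there. The key
# name "treatment_plan" is an assumption.
from typing import Callable, MutableMapping


def get_or_create_plan(state: MutableMapping, key: str,
                       generate: Callable[[], dict], regenerate: bool = False) -> dict:
    """Return the plan cached under `key`, calling `generate` only when needed."""
    if regenerate or key not in state:
        state[key] = generate()
    return state[key]

# Usage (sketch; clear the key when a new image is uploaded so a stale plan is not
# shown for a different patient):
#   clicked = st.button("Generate Treatment Plan", type="primary")
#   if clicked or "treatment_plan" in st.session_state:
#       plan = get_or_create_plan(
#           st.session_state, "treatment_plan",
#           lambda: st.session_state.treatment_planner.generate_treatment_plan(prediction, patient_data),
#           regenerate=clicked)
#       display_treatment_plan(plan)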


def display_treatment_plan_with_validation(treatment_plan: Dict, prediction: Dict, actual_severity: str, patient_info: Dict):
    st.markdown('<div class="treatment-section"><h2>Personalized Treatment Plan</h2></div>', unsafe_allow_html=True)
    
    # Show validation status at the top
    is_correct = prediction["predicted_class"] == actual_severity
    if is_correct:
        st.success(f"✅ **AI Classification Correct** - Treatment plan based on accurate {prediction['predicted_class']} diagnosis")
    else:
        st.error(f"⚠️ **AI Misclassification** - Treatment plan based on AI prediction ({prediction['predicted_class']}) but actual severity is {actual_severity}")
        st.warning("**Clinical Note:** Consider this classification discrepancy when evaluating treatment recommendations")
    
    # Check if this is an LLM-generated response or rule-based
    if "llm_response" in treatment_plan:
        # LLM-generated treatment plan
        st.subheader("🤖 AI-Generated Treatment Plan")
        st.write(treatment_plan["llm_response"])
        
        # Show additional metadata
        st.caption(f"Generated by: {treatment_plan.get('generated_by', 'AI')} at {treatment_plan.get('timestamp', 'Unknown time')}")
        
        # Add note about LLM usage
        st.info("💡 This treatment plan was generated by a large language model. Review it carefully against current clinical guidelines before acting on it.")
        
    else:
        # Rule-based treatment plan
        st.subheader("Primary Treatment Approach")
        st.write(treatment_plan["primary_approach"])
        
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Pharmacological Interventions")
            for med in treatment_plan["medications"]:
                st.write(f"• {med}")
            
            st.subheader("Non-Pharmacological Interventions")
            for item in treatment_plan["non_pharmacological"]:
                st.write(f"• {item}")
        
        with col2:
            st.subheader("Lifestyle Modifications")
            for item in treatment_plan["lifestyle"]:
                st.write(f"• {item}")
                
            if treatment_plan.get("surgical_options"):
                st.subheader("Surgical Considerations")
                for opt in treatment_plan["surgical_options"]:
                    st.write(f"• {opt}")

        st.subheader("Follow-up Recommendations")
        for r in treatment_plan["follow_up"]:
            st.write(f"• {r}")

        st.markdown('<div class="warning-box"><h4>⚠️ Clinical Red Flags - Seek Immediate Attention</h4></div>', unsafe_allow_html=True)
        for flag in treatment_plan["red_flags"]:
            st.write(f"🚨 {flag}")

    # Clinical validation section
    st.markdown("---")
    st.subheader("Clinical Validation & Quality Assurance")
    
    col_pred, col_actual, col_match = st.columns(3)
    with col_pred:
        st.metric("AI Prediction", prediction["predicted_class"], delta=f"{prediction['confidence']:.1%} confidence")
    with col_actual:
        st.metric("Actual Classification", actual_severity)
    with col_match:
        match_status = "✅ Correct" if is_correct else "❌ Incorrect"
        if is_correct:
            st.success(f"**{match_status}**")
        else:
            st.error(f"**{match_status}**")
    
    # Detailed clinical analysis
    with st.expander("Detailed Clinical Analysis"):
        st.write(f"**Patient:** {patient_info['name']} ({patient_info['age']}{patient_info['gender'][0]})")
        st.write(f"**Occupation:** {patient_info['occupation']}")
        st.write(f"**AI Prediction:** {prediction['predicted_class']} (Confidence: {prediction['confidence']:.1%})")
        st.write(f"**Ground Truth:** {actual_severity}")
        st.write(f"**Treatment Based On:** AI prediction ({prediction['predicted_class']})")
        
        if is_correct:
            st.success("✅ **Validation Passed** - AI prediction matches expert annotation")
            st.write("**Clinical Implications:**")
            st.write("- Treatment recommendations are based on a correct AI grade for this case")
            st.write("- Clinician review of the plan is still required before use")
        else:
            st.error("❌ **Validation Failed** - Misclassification detected")
            st.write("**Clinical Implications:**")
            st.write(f"- AI predicted **{prediction['predicted_class']}** but actual severity is **{actual_severity}**")
            st.write("- Treatment plan may not be optimal for actual condition")
            st.write("- Recommend additional clinical assessment")
            st.write("- Consider expert radiologist review")
            
            # Flag that the plan may not fit the actual grade (no alternative plan is generated)
            st.write("**Alternative Treatment Consideration:**")
            st.write(f"The treatment approach for the actual **{actual_severity}** grade may differ from this plan")

    # Clinical references (only for rule-based plans)
    if "references" in treatment_plan:
        with st.expander("📚 Clinical References"):
            for ref in treatment_plan["references"]:
                st.write(f"• {ref}")

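    # Export options (placeholders: nothing is copied, exported, or saved yet; like the
    # plan itself, these render only on the run where the generate button was clicked)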
    st.subheader("Export Options")
    c1, c2, c3 = st.columns(3)
    with c1:
        if st.button("Copy to Clipboard"):
            st.success("Treatment plan copied!")
    with c2:
        if st.button("Generate Clinical Report"):
            # Assemble the report payload (not yet rendered or saved anywhere)
            report_data = {
                "patient_name": patient_info['name'],
                "ai_prediction": prediction['predicted_class'],
                "actual_classification": actual_severity,
                "prediction_correct": is_correct,
                "confidence": prediction['confidence'],
                "treatment_plan": treatment_plan
            }
            st.success("Clinical report with validation generated!")
    with c3:
        if st.button("Save to Patient Record"):
            st.success("Saved to patient record with validation notes!")
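

# Sketch (hypothetical helper, not used by the app yet): the "Generate Clinical
# Report" button above assembles `report_data` but never writes it anywhere.
# Assuming the plan dict is mostly JSON-friendly, the payload could be serialized
# and offered through st.download_button. The "generated_at" field and the float()
# cast (model confidences are often NumPy scalars, which json cannot serialize)
# are additions for illustration; other unknown objects fall back to str().
import json
from datetime import datetime, timezone


def build_clinical_report_json(patient_name: str, prediction: dict,
                               actual_severity: str, treatment_plan: dict) -> str:
    """Serialize the validation report as indented JSON text."""
    report = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "patient_name": patient_name,
        "ai_prediction": prediction["predicted_class"],
        "actual_classification": actual_severity,
        "prediction_correct": prediction["predicted_class"] == actual_severity,
        "confidence": float(prediction["confidence"]),
        "treatment_plan": treatment_plan,
    }
    return json.dumps(report, indent=2, default=str)

# Usage (sketch):
#   st.download_button("Download Clinical Report (JSON)",
#                      data=build_clinical_report_json(patient_info['name'], prediction,
#                                                      actual_severity, treatment_plan),
#                      file_name="clinical_report.json", mime="application/json")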


def batch_processing_interface():
    st.header("Batch Processing")
    st.info("Upload multiple X-ray images for batch analysis")

    uploaded_files = st.file_uploader("Upload multiple X-ray images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)

    if uploaded_files:
        st.write(f"Uploaded {len(uploaded_files)} images")
        if st.button("Process Batch"):
            results = []
            progress = st.progress(0)
            for i, f in enumerate(uploaded_files):
                img = Image.open(f)
                pred = st.session_state.model.predict(img)
                results.append({"filename": f.name, "severity": pred["predicted_class"], "confidence": pred["confidence"]})
                progress.progress((i + 1) / len(uploaded_files))

            df = pd.DataFrame(results)
            st.dataframe(df)

            st.subheader("Batch Summary")
            counts = df["severity"].value_counts()
            fig = px.pie(values=counts.values, names=counts.index, title="Severity Distribution")
            st.plotly_chart(fig)
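

# Sketch (hypothetical helpers, not used by the app yet): in the batch loop above,
# one unreadable upload raises inside Image.open() or predict() and stops the
# whole batch. A tolerant variant records per-file errors instead, and the rows
# can be exported as CSV (e.g. through st.download_button, or df.to_csv(index=False)
# with pandas). The `load` and `predict` callables stand in for Image.open and
# st.session_state.model.predict.
import csv
import io


def process_batch(files, load, predict):
    """Return (rows, errors); each file is handled independently."""
    rows, errors = [], []
    for f in files:
        name = getattr(f, "name", str(f))  # uploaded files expose .name
        try:
            pred = predict(load(f))
        except Exception as exc:  # keep going and report the failure for this file
            errors.append({"filename": name, "error": str(exc)})
            continue
        rows.append({"filename": name,
                     "severity": pred["predicted_class"],
                     "confidence": float(pred["confidence"])})
    return rows, errors


def rows_to_csv(rows) -> str:
    """Serialize batch rows to CSV text with a header line."""
    buffer = io.StringIO()
    writer = csv.DictWriter(buffer, fieldnames=["filename", "severity", "confidence"])
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()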


def analytics_dashboard():
    st.header("Analytics Dashboard")
    st.info("Coming soon: Patient analytics and outcome tracking")


def settings_interface():
    st.header("Settings")

    st.subheader("AI Model Configuration")
    
    # Display detailed model information
    default_ckpt = REPO_ROOT / "models" / "deployment" / "best_model_for_deployment.pth"
    if default_ckpt.exists():
        try:
            # weights_only=False unpickles arbitrary objects: only load trusted checkpoints
            ckpt = torch.load(default_ckpt, map_location="cpu", weights_only=False)
            
            st.success("**Trained Model Information**")
            col1, col2 = st.columns(2)
            
            with col1:
                st.metric("Model Type", ckpt.get("model_type", "unknown"))
                st.metric("Deployment Strategy", ckpt.get("deployment_strategy", "unknown"))
                st.metric("Ensemble Strategy", ckpt.get("ensemble_strategy", "unknown"))
                
            with col2:
                test_results = ckpt.get("test_results", {})
                ensemble_result = test_results.get("ensemble_result", {})
                st.metric("Test Accuracy", f"{ensemble_result.get('accuracy', 0):.2f}%")
                st.metric("Test F1 Score", f"{ensemble_result.get('f1', 0):.4f}")
                st.metric("Temperature", f"{ckpt.get('temperature', 1.0):.4f}")
            
            # Show model weights if it's an ensemble
            if ckpt.get("ensemble_strategy") != "Top-1 (Best Individual)":
                with st.expander("Ensemble Model Weights"):
                    weights = ckpt.get("ensemble_weights", [])
                    model_configs = ckpt.get("model_configs", [])
                    if weights and model_configs:
                        for cfg, weight in zip(model_configs, weights):
                            st.write(f"• **{cfg.get('display_name', 'Unknown')}**: {weight:.4f}")
            else:
                with st.expander("Individual Model Details"):
                    weights = np.array(ckpt.get("ensemble_weights", []))
                    model_configs = ckpt.get("model_configs", [])
                    if len(weights) > 0 and len(model_configs) > 0:
                        selected_idx = np.argmax(weights)
                        selected_model = model_configs[selected_idx]
                        st.write(f"**Selected Model**: {selected_model.get('display_name', 'Unknown')}")
                        st.write(f"**Architecture**: {selected_model.get('model_name', 'Unknown')}")
                        st.write(f"**Confidence Weight**: {weights[selected_idx]:.4f}")
                        
        except Exception as e:
            st.error(f"Error loading model info: {e}")
    else:
        st.warning("No trained model found - using demo model")
        st.info("Run notebook 04 to train your model")
    
    _ = st.slider("Confidence Threshold (UI only)", 0.0, 1.0, 0.7, 0.05)

    st.subheader("LLM Integration")
    
    # Check current API key status
    current_api_key = get_openai_key()
    if current_api_key:
        st.success("OpenAI API key configured via secrets.toml")
        st.info("AI-powered treatment recommendations are enabled")
        
        # Show API key status (masked)
        masked_key = current_api_key[:8] + "..." + current_api_key[-4:] if len(current_api_key) > 12 else "***"
        st.code(f"API Key: {masked_key}")
        
        if st.button("Refresh LLM Integration"):
            st.session_state.treatment_planner = ClinicalTreatmentPlanner(current_api_key)
            st.success("LLM integration refreshed!")
    else:
        st.warning("OpenAI API key not found")
        st.info("Add your API key to `.streamlit/secrets.toml` to enable AI-powered treatment recommendations")
        
        # Manual override option
        with st.expander("Manual API Key Override"):
            api_key = st.text_input("Temporary OpenAI API Key", type="password", help="Enter your OpenAI API key")
            if api_key:
                st.session_state.treatment_planner = ClinicalTreatmentPlanner(api_key)
                st.success("LLM integration enabled for this session!")
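

# Sketch (hypothetical helper, not used by the app yet): the "Confidence Threshold
# (UI only)" slider above is read but discarded. Assuming predictions keep the
# {"confidence", "all_probabilities"} shape used elsewhere in this notebook, one
# way to use the value is to flag low-confidence results for clinician review.
# The top-2 margin check and its 0.1 default are illustrative assumptions, not
# values derived from the trained model.
def needs_review(prediction: dict, threshold: float, min_margin: float = 0.1) -> bool:
    """Flag a prediction whose confidence or top-2 probability margin is low."""
    if prediction["confidence"] < threshold:
        return True
    probs = sorted(prediction["all_probabilities"].values(), reverse=True)
    return len(probs) > 1 and probs[0] - probs[1] < min_margin

# Usage (sketch, assuming the slider value is made available to the prediction view):
#   if needs_review(prediction, threshold):
#       st.warning("Low-confidence classification: recommend radiologist review")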


# Run app
def _running_inside_streamlit() -> bool:
    """Return True when launched via `streamlit run`, False in notebooks/REPL."""
    try:
        from streamlit import runtime  # Streamlit ≥1.25
        return runtime.exists()
    except Exception:
        return bool(os.environ.get("STREAMLIT_SERVER_ENABLED"))

if __name__ == "__main__":
    if _running_inside_streamlit():
        main()
    else:
        print("This notebook contains the clinical application code.")
        print("To launch the app, use one of these methods:")
        print()
        print("  • Double-click: start_clinical_app.bat")
        print("  • Command line: streamlit run clinical_app_standalone.py")
        print("  • Notebook: streamlit run notebooks/05_Deployment_and_Clinical_UI.ipynb")

print("Clinical UI Application defined")




# **HOW TO LAUNCH THE CLINICAL APPLICATION**

## **Simple Click-to-Launch Options:**

### **EASIEST: Double-Click to Launch**
1. **Double-click:** `start_clinical_app.bat` (in main folder)
2. **Wait a few seconds** for the app to start
3. **Browser opens automatically** to the clinical interface

### **Alternative: Command Line**
```bash
# From main project folder:
streamlit run clinical_app_standalone.py
```

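### **Alternative: Run This Notebook**
`streamlit run` expects a Python (`.py`) script, so the notebook may need to be exported first:
```bash
# From main project folder:
jupyter nbconvert --to script notebooks/05_Deployment_and_Clinical_UI.ipynb
streamlit run notebooks/05_Deployment_and_Clinical_UI.py
```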

---

## **What You'll Get:**
- **Professional clinical web interface** at `http://localhost:8501`
- **AI-powered X-ray analysis** with confidence scores
- **Evidence-based treatment recommendations**
- **Batch processing** for multiple patients
- **Analytics dashboard** (placeholder; coming soon)

---

**NOTE:** Do not execute notebook cells directly. Use the launch methods above for the full application experience.
