In [5]:
!pip install git+https://github.com/jacobgil/pytorch-grad-cam.git -q

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
  Building wheel for grad-cam (pyproject.toml) ... done


In [44]:
import gradio as gr
import torch
import torch.nn as nn
import timm
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os

# Markdown recommendation texts per class label: (high_confidence_text, moderate_confidence_text).
# Keys must match the labels in `class_names`; labels without an entry get a generic notice.
_DISEASE_RECOMMENDATIONS = {
    "Calculus": (
        """
## 🦷 **Calculus (Tartar) Detected - High Confidence**

**Overview:**
- Calculus, also known as tartar, is hardened plaque that forms on teeth when plaque mineralizes with calcium from saliva. It adheres strongly to tooth surfaces and can only be removed professionally. This buildup increases the risk of gingivitis and periodontitis by harboring bacteria that irritate gums.

**Symptoms and Risks:**
- Visible yellowish or brownish deposits along the gumline.
- Bad breath (halitosis), gum inflammation, and potential bleeding during brushing.
- If untreated, it can lead to gum recession, bone loss, and tooth mobility.

**Immediate Actions:**
- 🗓️ Schedule a dental cleaning (scaling) within 2-4 weeks as per ADA recommendations to remove supra- and sub-gingival calculus.
- 🪥 Enhance brushing technique: Use a soft-bristled toothbrush at a 45° angle to gums, brushing twice daily for two minutes with fluoride toothpaste.
- 🦷 Floss daily to disrupt plaque between teeth and under the gumline.

**Professional Care:**
- Scaling and root planing by a dental hygienist to remove calculus and smooth root surfaces.
- Regular 6-month cleanings for prevention, or more frequent if high risk (e.g., every 3-4 months for periodontal patients).
- Assessment for periodontal disease if inflammation or pockets are present.

**Home Care and Prevention:**
- Incorporate antiseptic mouthwash (e.g., chlorhexidine if prescribed) to reduce bacterial load.
- Use tartar-control toothpaste containing pyrophosphates or zinc citrate to inhibit mineralization.
- Reduce intake of staining agents like coffee, tea, and tobacco.
- Consider an electric toothbrush, which can remove up to 22% more plaque than manual brushing.
- Limit sugary foods to prevent further plaque accumulation.
            """,
        """
## ⚠️ **Possible Calculus - Moderate Confidence**

**Recommended Actions:**
- 🏥 Schedule dental check-up within 4-6 weeks for evaluation and potential cleaning.
- 🪥 Improve oral hygiene: Brush twice daily and floss to prevent progression.
- 📋 Monitor for signs like bad breath or gum bleeding.
            """,
    ),
    "Caries": (
        """
## 🔴 **Dental Caries (Cavities) Detected - High Confidence**

**Overview:**
- Dental caries is the destruction of tooth enamel and dentin by acid-producing bacteria, often from dietary sugars. It's the most common chronic disease worldwide, affecting billions.

**Symptoms and Risks:**
- Early signs include white spots on enamel; advanced stages cause pain, sensitivity, visible holes, or infection (abscess).
- Risks include poor oral hygiene, high-sugar diet, dry mouth, and fluoride deficiency. Untreated, it can lead to tooth loss or systemic infections.

**Immediate Actions:**
- 🚨 Schedule dental visit within 1-2 weeks for assessment and treatment.
- 🦷 Avoid sugary foods/drinks and brush with fluoride toothpaste after meals to promote remineralization.
- Clean between teeth daily with floss or interdental cleaners.

**Treatment Options:**
- Nonrestorative: Fluoride varnish or silver diamine fluoride to arrest early lesions.
- Restorative: Dental fillings for cavitated lesions; root canal or crown if pulp is involved.
- For children under 5 at high risk, apply fluoride varnish every 3-6 months.

**Pain Management:**
- Use over-the-counter analgesics (e.g., ibuprofen) for discomfort.
- Avoid hot/cold/sweet stimuli on affected teeth.
- Desensitizing toothpaste for mild sensitivity.

**Prevention:**
- Apply dental sealants on molars to protect pits and fissures.
- Limit sugar frequency and ensure fluoridated water intake (ADA recommendation).
- Caries risk assessment to tailor preventive strategies like professional fluoride treatments.
            """,
        """
## ⚠️ **Possible Dental Caries - Moderate Confidence**

**Recommended Actions:**
- 🏥 Schedule dental examination within 2-3 weeks.
- 🦷 Monitor for sensitivity, pain, or visible spots.
- 🍭 Reduce sugar intake and enhance fluoride use.
            """,
    ),
    "Gingivitis": (
        """
## 🔴 **Gingivitis Detected - High Confidence**

**Overview:**
- Gingivitis is reversible inflammation of the gums caused by plaque buildup, affecting over 50% of adults. It's the early stage of periodontal disease.

**Symptoms and Risks:**
- Red, swollen gums that bleed easily during brushing or flossing.
- Risks include poor hygiene, smoking, diabetes, and certain medications. Progression to periodontitis can cause bone loss.

**Immediate Actions:**
- 🗓️ Schedule dental visit within 1-2 weeks for professional cleaning.
- 🩸 Note bleeding and improve brushing/flossing.
- Brush twice daily with a soft brush for two minutes.

**Professional Care:**
- Scaling to remove plaque and calculus (ADA guideline for chronic periodontitis initial treatment).
- Antimicrobial rinse (e.g., chlorhexidine) for adjunct therapy.
- Oral hygiene education and possible 3-month recalls for high-risk patients.

**Home Care:**
- Floss correctly, curving around each tooth base.
- Use interdental brushes or water flosser for better access.
- Replace toothbrush every 3 months.

**Lifestyle:**
- 🚭 Cease smoking, a major risk factor for progression.
- 🥗 Increase vitamin C intake for gum health.
- 💧 Maintain hydration to support saliva flow.
            """,
        """
## ⚠️ **Possible Gingivitis - Moderate Confidence**

**Recommended Actions:**
- 🏥 Schedule dental evaluation within 2-3 weeks.
- 👀 Monitor for bleeding, redness, or swelling.
- 🪥 Enhance oral hygiene routine.
            """,
    ),
    "Mouth Ulcer": (
        """
## 🔴 **Mouth Ulcer Detected - High Confidence**

**Overview:**
- Mouth ulcers, or canker sores, are painful lesions on oral mucosa, often aphthous in nature. They typically heal in 1-2 weeks but can recur.

**Symptoms and Risks:**
- Round, white/yellow sores with red borders; pain during eating/speaking.
- Triggers include stress, trauma, acidic foods, or deficiencies (e.g., vitamin B12). Recurrent cases may link to systemic conditions.

**Immediate Care:**
- 💊 Apply topical anesthetics (e.g., benzocaine) or protective pastes for pain relief.
- 🧴 Rinse with salt water or baking soda solution multiple times daily.
- 🍽️ Avoid spicy, acidic, or abrasive foods to prevent irritation.
- 🧊 Dissolve ice chips on the sore for temporary numbness.

**Seek Immediate Care If:**
- 🚨 Ulcer persists beyond 2 weeks or worsens.
- 🚨 Accompanied by fever, swollen lymph nodes, or difficulty swallowing.
- 🚨 Ulcers are large (>1cm), multiple, or recurrent.

**Healing Support:**
- Maintain gentle oral hygiene with soft-bristled toothbrush.
- Avoid tobacco and alcohol, which delay healing.
- Consider over-the-counter treatments like honey or aloe vera for soothing.

**Medical Evaluation:**
- For recurrent ulcers, investigate underlying causes (e.g., Behcet's disease).
- Prescription options like amlexanox paste if needed.
            """,
        """
## ⚠️ **Possible Mouth Ulcer - Moderate Confidence**

**Recommended Actions:**
- 👀 Monitor healing over 7-10 days.
- 🏥 Consult dentist if no improvement in 2 weeks or symptoms worsen.
- 🍎 Avoid irritating foods and maintain gentle hygiene.
            """,
    ),
    "Oral Cancer": (
        """
## 🚨 **Oral Cancer Suspected - High Confidence**

**URGENT ACTION REQUIRED**

**Overview:**
- Oral cancer includes malignancies of the mouth and oropharynx, often linked to tobacco, alcohol, and HPV. Early detection improves 5-year survival from 50% to over 80%.

**Symptoms and Risks:**
- Persistent sores, lumps, white/red patches, difficulty swallowing, or numbness.
- High risks: Smoking, heavy alcohol use, HPV infection, sun exposure (lips), and family history.

**Immediate Actions:**
- 🚨 Contact dentist or oral surgeon within 24-48 hours for evaluation.
- 🚨 Request urgent biopsy for definitive diagnosis.
- 📋 Document symptoms, changes, and medical history.

**Do Not:**
- ❌ Delay evaluation, even if painless.
- ❌ Self-treat or ignore lesions.

**Diagnostic and Treatment Steps:**
- 🔍 Comprehensive oral exam and biopsy.
- 📊 Imaging (CT/MRI) and staging (stages 0-IV).
- Treatment: Surgery for early stages (I-II), possibly with radiation/chemo; advanced stages may include targeted therapy or immunotherapy.

**Prevention:**
- Avoid tobacco/alcohol; get HPV vaccine.
- Limit UV exposure; maintain healthy weight.
            """,
        """
## ⚠️ **Possible Oral Cancer Concern - Moderate Confidence**

**Recommended Actions:**
- 🏥 Schedule specialist visit within 1 week.
- 👀 Monitor and document lesion changes.
- 📋 Note associated symptoms like pain or swelling.
            """,
    ),
    "Oral Lichen Planus": (
        """
## 🔴 **Oral Lichen Planus Detected - High Confidence**

**Overview:**
- Oral lichen planus (OLP) is a chronic autoimmune condition affecting oral mucosa, with a 1-2% prevalence and small (1%) risk of malignant transformation.

**Symptoms and Risks:**
- White lacy patches, erosions, or ulcers; pain or burning sensation.
- Triggers: Stress, medications, or dental materials. Risks include hepatitis C association and potential progression to cancer.

**Medical Management:**
- 🏥 Consult oral medicine specialist for biopsy-confirmed diagnosis.
- 💊 Topical corticosteroids (e.g., clobetasol 0.05%) for erosive forms; systemic if severe.
- 📅 Follow-up every 6-12 months to monitor for changes.

**Symptom Relief:**
- Avoid spicy, acidic, or rough foods to minimize irritation.
- Use soft toothbrush and mild toothpaste.
- Stress management (e.g., relaxation techniques).

**Monitoring:**
- 🔍 Regular exams for malignant transformation signs.
- 🚨 Report non-healing ulcers or new lesions immediately.

**Lifestyle:**
- 🚭 Quit tobacco and limit alcohol.
- 🥗 Balanced diet; avoid potential allergens.
            """,
        """
## ⚠️ **Possible Oral Lichen Planus - Moderate Confidence**

**Recommended Actions:**
- 🏥 Schedule oral medicine consultation within 2-3 weeks.
- 👀 Monitor for white patches or erosions.
- 🚭 Avoid irritants like tobacco and alcohol.
            """,
    ),
    "Tooth Discoloration": (
        """
## 🟡 **Tooth Discoloration Detected - High Confidence**

**Overview:**
- Tooth discoloration can be extrinsic (surface stains) or intrinsic (internal changes), affecting aesthetics and sometimes indicating underlying issues.

**Symptoms and Risks:**
- Yellow, brown, or gray teeth; spots or bands.
- Causes: Poor hygiene, staining foods (coffee, wine), tobacco, aging, trauma, or medications (e.g., tetracycline).

**Professional Options:**
- 🦷 Professional cleaning and polishing to remove extrinsic stains.
- 💡 Teeth whitening with hydrogen peroxide (in-office or at-home trays) for vital teeth.
- 🔍 Evaluate for erosion or internal issues; veneers or crowns for severe cases.

**Treatment Considerations:**
- Only natural teeth whiten; restorations do not.
- For erosion-related discoloration, manage acid exposure (e.g., from diet or reflux).

**Prevention:**
- Brush and floss daily; 6-month cleanings.
- Limit staining agents; use straws for beverages.
- For children, monitor fluoride intake to prevent fluorosis.
            """,
        """
## ⚠️ **Possible Tooth Discoloration - Moderate Confidence**

**Recommended Actions:**
- 🏥 Schedule dental cleaning and evaluation.
- 🍷 Reduce staining substances like coffee or tobacco.
- 🪥 Improve brushing and flossing routine.
            """,
    ),
    "Hypodontia": (
        """
## 🔵 **Hypodontia (Missing Teeth) Detected - High Confidence**

**Overview:**
- Hypodontia is the congenital absence of 1-5 teeth (excluding third molars), affecting 2-8% of the population, often genetic or environmental.

**Symptoms and Risks:**
- Gaps in dentition, affecting chewing, speech, and aesthetics; may cause malocclusion or jaw issues.
- Associated with syndromes (e.g., ectodermal dysplasia) or isolated.

**Comprehensive Care:**
- 🏥 Consult prosthodontist or orthodontist for multidisciplinary planning.
- 📊 Panoramic X-ray and genetic evaluation if syndromic.
- 📝 Treatment plan based on age, number of missing teeth, and occlusion.

**Treatment Options:**
- Orthodontics (braces/Invisalign) to close spaces or prepare for prosthetics.
- Dental implants (preferred for adults) or bridges for replacement.
- Removable partial dentures for temporary or multiple missing teeth.

**Considerations:**
- 🍎 Functional adaptations for chewing and speech.
- 😃 Psychological support for aesthetic concerns.
- Bone grafting if needed for implants.

**Management:**
- Regular monitoring for dental development in children.
- Preserve space with retainers post-orthodontics.
            """,
        """
## ⚠️ **Possible Hypodontia - Moderate Confidence**

**Recommended Actions:**
- 🏥 Schedule comprehensive dental evaluation.
- 📊 Request panoramic X-ray for confirmation.
- 💭 Consider functional and aesthetic implications.
            """,
    ),
}


def get_disease_recommendations(predicted_class, confidence, high_confidence_threshold=70):
    """Return Markdown care recommendations for a predicted oral condition.

    Args:
        predicted_class: Class label predicted by the classifier, e.g. "Caries".
        confidence: Prediction confidence on the same scale as
            ``high_confidence_threshold``. ``predict_and_visualize`` passes a
            percentage (probability * 100).
        high_confidence_threshold: Inclusive cut-off at or above which the detailed
            high-confidence text is returned. Below it, the shorter
            moderate-confidence text is used. Defaults to 70, the value that was
            previously hard-coded for every class.

    Returns:
        A Markdown string. Labels without an entry in ``_DISEASE_RECOMMENDATIONS``
        get a generic "no specific recommendations" notice.
    """
    texts = _DISEASE_RECOMMENDATIONS.get(predicted_class)
    if texts is None:
        return "## ℹ️ No specific recommendations available for this condition."
    high_confidence_text, moderate_confidence_text = texts
    if confidence >= high_confidence_threshold:
        return high_confidence_text
    return moderate_confidence_text


# --- Inference-time preprocessing ---
# Resize every image to 300x300, normalise each channel with the standard
# ImageNet mean/std, then convert the HWC numpy array into a CHW torch tensor.
test_transforms = A.Compose(
    [
        A.Resize(height=300, width=300),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ],
    p=1.0,
)

# Output index -> label. Index i of the model's logits is reported as class_names[i],
# so this order presumably has to match the label encoding used in training.
class_names = [
    'Calculus',
    'Caries',
    'Gingivitis',
    'Mouth Ulcer',
    'Oral Cancer',
    'Oral Lichen Planus',
    'Tooth Discoloration',
    'Hypodontia',
]

class CustomEfficientnet(nn.Module):
    """EfficientNet-B4 backbone (timm) with a deeper MLP classification head.

    The head is Linear -> BatchNorm1d -> ReLU -> Dropout repeated for widths
    1024/512/256/128, followed by a final Linear to ``num_classes`` logits.
    The submodule order determines the state_dict keys
    (``model.classifier.0.weight`` ... ``model.classifier.16.bias``), so it must
    stay fixed for saved checkpoints to keep loading.

    Args:
        num_classes: Number of output logits.
        pretrained: Whether timm downloads ImageNet weights for the backbone.
            Defaults to True, the original behaviour. Pass False when a full
            checkpoint is loaded right afterwards, because ``load_state_dict``
            then overwrites every backbone weight and the download is wasted.
    """

    # (hidden width, dropout probability) for each hidden block of the head.
    _HEAD_SPEC = ((1024, 0.6), (512, 0.5), (256, 0.4), (128, 0.4))

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.model = timm.create_model('efficientnet_b4', pretrained=pretrained, drop_path_rate=0.15)
        # Fine-tune the whole backbone rather than using it as a frozen extractor.
        for param in self.model.parameters():
            param.requires_grad = True

        # Replace the stock classifier with the MLP head (same layer order as before).
        layers = []
        prev_width = self.model.classifier.in_features
        for width, drop_prob in self._HEAD_SPEC:
            layers += [
                nn.Linear(prev_width, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(drop_prob),
            ]
            prev_width = width
        layers.append(nn.Linear(prev_width, num_classes))
        self.model.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        return self.model(x)

# --- Build the network and restore the fine-tuned weights from best_model.pth ---
num_classes = len(class_names)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

best_model = CustomEfficientnet(num_classes=num_classes)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
best_model.load_state_dict(
    torch.load("best_model.pth", map_location=device)
)
# Module.to() and Module.eval() both return the module itself, so this rebinds the same object.
best_model = best_model.to(device).eval()
print('Model loaded and ready.')

# --- Grad-CAM implementation ---
# Gradient-weighted Class Activation Mapping (Selvaraju et al., 2017), built on
# PyTorch forward / full-backward module hooks.
class GradCAM:
    """Produce Grad-CAM heatmaps for one layer of a classification model.

    A forward hook on ``target_layer`` records its output (activations). A full
    backward hook records the gradient of the chosen class score with respect to
    that output. The channel weights are the gradients averaged over dims 2 and 3.
    The heatmap is the ReLU of the weighted channel sum, min-max scaled to [0, 1].

    Args:
        model: Module mapping an input batch to logits of shape (batch, classes).
        target_layer: Submodule of ``model``. Its output is assumed to be 4-D
            (N, C, H, W), which the mean over dims (2, 3) requires.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # d(score)/d(target_layer output), set by the backward hook
        self.activations = None  # target_layer output, set by the forward hook
        self._hook_handles = []  # RemovableHandles so remove_hooks() can detach them
        self.hook_forward()
        self.hook_backward()

    def hook_forward(self):
        """Register a forward hook that stores the target layer's detached output."""
        def forward_hook(module, inputs, output):
            self.activations = output.detach()
        self._hook_handles.append(self.target_layer.register_forward_hook(forward_hook))

    def hook_backward(self):
        """Register a full backward hook that stores the gradient w.r.t. the layer output."""
        def backward_hook(module, grad_input, grad_output):
            self.gradients = grad_output[0].detach()
        self._hook_handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach every hook this instance registered on ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def generate_heatmap(self, input_tensor, target_class=None):
        """Compute the Grad-CAM map explaining one class score of sample 0.

        Args:
            input_tensor: Batch fed to the model. Only the score of sample 0 is
                back-propagated (``model_output[0, target_class]``).
            target_class: Class index as an int or a one-element tensor. Defaults
                to the argmax of sample 0's logits.

        Returns:
            The squeezed CAM as a numpy array with values in [0, 1]. For a batch
            of one this has shape (H, W) of the target layer's feature map. It is
            all zeros, rather than NaN, when the CAM is identically zero.

        Raises:
            RuntimeError: If the hooks captured no activations or gradients, for
                example because ``target_layer`` is not used in the forward pass.
        """
        # Drop results of any previous forward pass so stale tensors are never reused.
        self.gradients = None
        self.activations = None

        # Gradients are needed even when the caller is inside torch.no_grad().
        with torch.enable_grad():
            model_output = self.model(input_tensor)
            if target_class is None:
                target_class = torch.argmax(model_output[0])
            # int() accepts ints and one-element tensors, so `score` is always 0-d.
            score = model_output[0, int(target_class)]
            self.model.zero_grad(set_to_none=True)
            # The graph is used exactly once, so it need not be retained.
            score.backward()
        # Free the parameter gradients this backward pass accumulated.
        self.model.zero_grad(set_to_none=True)

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "check that target_layer is used in the model's forward pass."
            )

        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * self.activations).sum(dim=1, keepdim=True))
        cam = cam - cam.min()
        cam_max = cam.max()
        # An all-zero CAM (e.g. every channel sum <= 0) would otherwise become NaN via 0/0.
        if cam_max > 0:
            cam = cam / cam_max
        return cam.squeeze().cpu().numpy()


# Explain predictions at the last entry of the EfficientNet backbone's `blocks`
# sequence; GradCAM hooks into this module's output.
target_layer = best_model.model.blocks[-1]
grad_cam = GradCAM(best_model, target_layer)

# --- Function to process image and generate outputs ---
def _save_grad_cam_overlay(image, heatmap, title, path="grad_cam_output.png", size=300):
    """Overlay ``heatmap`` (values in [0, 1]) on ``image`` at size x size, save a PNG to ``path`` and return ``path``."""
    base = np.array(image.resize((size, size))) / 255.0
    heat = Image.fromarray(np.uint8(255 * heatmap)).resize((size, size), Image.LANCZOS)
    heat = np.array(heat) / 255.0

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.imshow(base)
        plt.imshow(heat, cmap='jet', alpha=0.5)
        plt.axis('off')
        plt.title(title, fontsize=14)
        plt.tight_layout()
        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if plotting fails, to avoid leaking figures.
        plt.close(fig)
    return path


def predict_and_visualize(input_image):
    """Classify an oral image, render its Grad-CAM overlay and build the report.

    Args:
        input_image: Image from the Gradio ``gr.Image`` component, either a PIL
            image or a numpy array. It is ``None`` when Predict is clicked
            without an uploaded image.

    Returns:
        A 3-tuple ``(results_markdown, heatmap_path, recommendations_markdown)``
        for the three Gradio outputs. ``heatmap_path`` is ``"grad_cam_output.png"``,
        or ``None`` when no image was supplied.
    """
    if input_image is None:
        return "## ⚠️ Please upload an image before clicking Predict.", None, ""

    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image.astype('uint8'))
    # Force 3-channel RGB: uploads may be RGBA, palette (P) or grayscale (L), which
    # would not match the 3-channel normalisation or the model's input channels.
    rgb_image = input_image.convert("RGB")

    transformed = test_transforms(image=np.array(rgb_image))
    input_tensor = transformed["image"].unsqueeze(0).to(device)  # add batch dimension

    # Class probabilities over all classes, then the two most likely labels.
    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(best_model(input_tensor)[0], dim=0)
    top2_probs, top2_indices = torch.topk(probabilities, 2)

    predicted_class_idx = top2_indices[0].item()
    predicted_label = class_names[predicted_class_idx]
    predicted_probability = top2_probs[0].item()

    predicted_label2 = class_names[top2_indices[1].item()]
    predicted_probability2 = top2_probs[1].item()

    # Grad-CAM for the top-1 class. nan_to_num is defensive: an all-zero CAM can
    # turn into NaN after min-max normalisation.
    heatmap = np.nan_to_num(
        grad_cam.generate_heatmap(input_tensor, target_class=predicted_class_idx),
        nan=0.0,
    )
    heatmap_path = _save_grad_cam_overlay(
        rgb_image,
        heatmap,
        f"Grad-CAM: {predicted_label} ({predicted_probability:.2%})",
    )

    # Recommendations are driven by the top-1 prediction, in percent.
    confidence = predicted_probability * 100
    recommendations = get_disease_recommendations(predicted_label, confidence)

    return f"""
## 🎯 AI Diagnosis Results

### 🥇 Primary Prediction
**Disease:** {predicted_label}
**Confidence Level:** {predicted_probability*100:.2f}%

### 🥈 Secondary Prediction
**Disease:** {predicted_label2}
**Confidence Level:** {predicted_probability2*100:.2f}%

---
*Note: Recommendations are based on the primary prediction.*
""", heatmap_path, recommendations


# --- Gradio user interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦷 Oral Disease Classification with Grad-CAM")
    gr.Markdown("### AI-Powered Diagnosis with Top 2 Predictions")

    with gr.Row():
        # Left column: upload widget and the trigger button.
        with gr.Column():
            image_input = gr.Image(label="📷 Upload Oral Image", type="pil")
            submit_btn = gr.Button("🔍 Predict", variant="primary")

        # Right column: text results, heatmap, and collapsible recommendations.
        with gr.Column():
            with gr.Group():
                gr.Markdown("### 📊 Prediction Results")
                label_output = gr.Markdown(label="Prediction Results")
            with gr.Group():
                gr.Markdown("### 🔍 Grad-CAM Visualization")
                grad_cam_output = gr.Image(label="Heatmap Visualization", height=300)
            with gr.Accordion("📋 Detailed Recommendations", open=False):
                recommendations_output = gr.Markdown()

    # Output order mirrors the 3-tuple returned by predict_and_visualize.
    submit_btn.click(
        predict_and_visualize,
        inputs=[image_input],
        outputs=[label_output, grad_cam_output, recommendations_output],
    )

# Start the app with a temporary public share link.
if __name__ == "__main__":
    demo.launch(share=True)

Model loaded and ready.
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://608f8cfb482993d3b7.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
