In [1]:
# Required libraries for the app
%%writefile dermAI.py
import streamlit as st
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
from io import BytesIO
import firebase_admin
from firebase_admin import credentials, auth
import time
import matplotlib.pyplot as plt

# Styling Enhancements: green rounded buttons and padded, rounded text inputs,
# injected as raw CSS (which is why unsafe_allow_html is required).
_PAGE_CSS = """
    <style>
    .stButton>button {
        background-color: #4CAF50;
        color: white;
        font-size: 16px;
        padding: 12px;
        border-radius: 5px;
    }
    .stTextInput>div>input {
        border-radius: 10px;
        padding: 10px;
    }
    </style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Page heading plus a one-line description of what the app does.
_INTRO_TEXT = (
    "Upload a skin lesion image (HAM10000 format) and get an AI prediction "
    "along with educational information."
)
st.title("🧠 DermAI - Skin Disease Classifier")
st.markdown(_INTRO_TEXT)

# Firebase initialization: gives the app access to Firebase Authentication using a
# service-account JSON key. The key path defaults to the original Colab location but can be
# overridden (cred_path argument or FIREBASE_CREDENTIALS env var), so the secret file does not
# have to live at a hard-coded path.
DEFAULT_FIREBASE_CREDENTIALS = '/content/derm-ai-6837e-firebase-adminsdk-fbsvc-40d6feb75b.json'


@st.cache_resource
def initialize_firebase(cred_path=None):
    """Initialise the default Firebase Admin app once and return it.

    Args:
        cred_path: Optional path to the service-account JSON key. When omitted, the
            FIREBASE_CREDENTIALS environment variable is used, falling back to
            DEFAULT_FIREBASE_CREDENTIALS.

    Returns:
        The default ``firebase_admin`` App (existing or newly initialised).

    Raises:
        FileNotFoundError: if no app exists yet and the resolved key file is missing.
    """
    import os

    try:
        # Reuse the default app if it exists; get_app() raises ValueError when it does not.
        return firebase_admin.get_app()
    except ValueError:
        path = cred_path or os.environ.get('FIREBASE_CREDENTIALS', DEFAULT_FIREBASE_CREDENTIALS)
        if not os.path.isfile(path):
            raise FileNotFoundError(
                f"Firebase service-account key not found at {path!r}; "
                "set FIREBASE_CREDENTIALS to the JSON key path."
            ) from None
        cred = credentials.Certificate(path)
        return firebase_admin.initialize_app(cred)


initialize_firebase()

# Load the lesion classifier from the Hugging Face Hub.
@st.cache_resource
def load_model():
    """Fetch the HAM10000 lesion classifier once per Streamlit server process.

    Returns:
        tuple: ``(model, processor, labels)`` - the image-classification model, its
        matching image processor, and the model's ``config.id2label`` mapping.
    """
    checkpoint = "ALM-AHME/convnextv2-large-1k-224-finetuned-Lesion-Classification-HAM10000-AH-60-20-20"
    classifier = AutoModelForImageClassification.from_pretrained(checkpoint)
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    return classifier, image_processor, classifier.config.id2label


# cache_resource makes every Streamlit rerun reuse these same objects.
model, processor, labels = load_model()

# Educational descriptions for the HAM10000 diagnostic categories, keyed by short code.
_EDUCATION_DESCRIPTIONS = {
    "mel": "Melanoma is a serious and aggressive form of skin cancer that originates in pigment-producing cells called melanocytes. It can occur anywhere on the skin and often appears as a dark or irregularly shaped mole.",
    "nv": "A melanocytic nevus is a benign mole or birthmark formed by clusters of melanocytes (pigment-producing cells). While they are usually harmless, they can sometimes change over time and should be monitored.",
    "bcc": "Basal cell carcinoma is the most common type of skin cancer. It typically occurs in sun-exposed areas of the skin and often appears as a small, shiny bump or a pink growth.",
    "akiec": "Actinic keratosis is a pre-cancerous skin lesion caused by prolonged sun exposure. It often appears as dry, scaly patches on sun-exposed areas, and if left untreated, it may develop into skin cancer.",
    "bkl": "Benign keratosis refers to non-cancerous growths on the skin, including solar lentigo, seborrheic keratosis, and lichen planus-like keratosis. These are often associated with aging and sun exposure.",
    "df": "Dermatofibromas are benign skin growths, typically firm, raised, and brownish in color. They are made up of fibrous tissue and are usually harmless but can occasionally become irritated.",
    "vasc": "Vascular lesions are caused by abnormalities in blood vessels and can range from harmless conditions like spider veins to more serious vascular malformations. They can appear as red or purple spots on the skin.",
}

# Long-form names (already normalised: lowercase, underscores) mapped onto the short codes.
# NOTE(review): the checkpoint's exact id2label strings are not visible in this file; these
# aliases cover common HAM10000 spellings so lookups don't silently fall back to the default.
_LABEL_ALIASES = {
    "melanoma": "mel",
    "melanocytic_nevi": "nv",
    "melanocytic_nevus": "nv",
    "nevus": "nv",
    "basal_cell_carcinoma": "bcc",
    "actinic_keratoses": "akiec",
    "actinic_keratosis": "akiec",
    "benign_keratosis": "bkl",
    "benign_keratosis_like_lesions": "bkl",
    "dermatofibroma": "df",
    "vascular_lesions": "vasc",
    "vascular_lesion": "vasc",
}


def get_education_description(disease_label):
    """Return a short educational blurb for a predicted skin-lesion label.

    Matching ignores case, surrounding whitespace, and space/hyphen-vs-underscore
    differences, and accepts either a HAM10000 short code (``"mel"``) or a long
    name (``"Melanoma"``, ``"Basal_Cell_Carcinoma"``).

    Args:
        disease_label: Label produced by the classifier; non-strings are treated as unknown.

    Returns:
        The description text, or ``"No description available for this disease."``.
    """
    fallback = "No description available for this disease."
    if not isinstance(disease_label, str):
        return fallback
    # "Benign keratosis-like lesions" -> "benign_keratosis_like_lesions"
    key = "_".join(disease_label.strip().lower().replace("-", " ").split())
    key = _LABEL_ALIASES.get(key, key)
    return _EDUCATION_DESCRIPTIONS.get(key, fallback)

# Function to handle user login/signup.
def _firebase_sign_in(email, password, api_key, timeout=10):
    """Verify an email/password pair with the Firebase Auth REST API.

    The Admin SDK can look users up but cannot check passwords, so this calls the
    ``accounts:signInWithPassword`` endpoint, which needs the project's Web API key.

    Returns:
        The decoded JSON response body.

    Raises:
        ValueError: if Firebase rejects the request; the message is the error string
            Firebase returned (or the HTTP status if the body could not be parsed).
        urllib.error.URLError: on network failures.
    """
    import json
    import urllib.error
    import urllib.parse
    import urllib.request

    url = (
        "https://identitytoolkit.googleapis.com/v1/accounts:signInWithPassword?key="
        + urllib.parse.quote(api_key, safe="")
    )
    payload = json.dumps(
        {"email": email, "password": password, "returnSecureToken": True}
    ).encode("utf-8")
    request = urllib.request.Request(
        url, data=payload, headers={"Content-Type": "application/json"}, method="POST"
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout) as response:
            return json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as err:
        # Rejected credentials come back as an HTTP error with {"error": {"message": ...}}.
        try:
            message = json.loads(err.read().decode("utf-8"))["error"]["message"]
        except (ValueError, KeyError, TypeError):
            message = f"HTTP {err.code}"
        raise ValueError(message) from err


def authenticate_user():
    """Render the sign-up / log-in form and handle its two buttons.

    * Login verifies the password through Firebase Auth's REST API. This requires
      the FIREBASE_WEB_API_KEY environment variable; without it an error is shown
      instead of a false success.
    * Sign Up creates the account with the Firebase Admin SDK.

    A successful login stores the email in ``st.session_state["user_email"]`` so it
    survives Streamlit reruns.

    Returns:
        The signed-in email, or None if nobody is signed in.
    """
    import os

    st.subheader("Sign Up or Log In")
    email = st.text_input("Email").strip()
    password = st.text_input("Password", type="password")

    login_clicked = st.button("Login")
    signup_clicked = st.button("Sign Up")

    if (login_clicked or signup_clicked) and not (email and password):
        st.error("Please enter both an email and a password.")
    elif signup_clicked:
        try:
            auth.create_user(email=email, password=password)
            st.success("Account created. You can now log in.")
        except Exception as e:
            st.error(f"Error: {e}")
    elif login_clicked:
        api_key = os.environ.get("FIREBASE_WEB_API_KEY")
        if not api_key:
            st.error("Login is not configured: set the FIREBASE_WEB_API_KEY environment variable.")
        else:
            try:
                _firebase_sign_in(email, password, api_key)
                st.session_state["user_email"] = email
                st.success("Logged in successfully!")
            except Exception as e:
                st.error(f"Error: {e}")

    return st.session_state.get("user_email")


authenticate_user()

# Image upload, classification, and results display.
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

confidence_percent = 0  # Default to 0 before processing
predicted_label = ""  # Default empty string before prediction


def _classify_image(pil_image):
    """Run the lesion classifier on one RGB PIL image.

    Returns:
        (label, confidence_percent, class_probs): the top-1 label from ``labels``,
        its softmax probability as a percentage, and the full softmax vector as a
        1-D numpy array indexed by class id.
    """
    inputs = processor(images=pil_image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
        # A single image is passed, so row 0 holds its class probabilities.
        class_probs = torch.nn.functional.softmax(logits, dim=1)[0].cpu().numpy()
    best_id = int(class_probs.argmax())
    return labels[best_id], float(class_probs[best_id]) * 100, class_probs


def _plot_confidence(class_probs, class_names):
    """Draw a horizontal bar chart of every class's probability, in percent."""
    fig, ax = plt.subplots(figsize=(3, 2))
    # Probabilities are fractions in [0, 1]; scale them so bars match the "%" axis label.
    ax.barh(class_names, class_probs * 100)
    ax.set_xlim(0, 100)
    ax.set_xlabel('Confidence (%)')
    ax.set_title('Class Confidence')
    fig.tight_layout()
    st.pyplot(fig)
    # Streamlit re-executes this script on every interaction; close figures so they don't pile up.
    plt.close(fig)


def _confidence_level(percent):
    """Bucket the top-1 confidence: "High" (>80), "Medium" (>50), otherwise "Low"."""
    if percent > 80:
        return "High"
    if percent > 50:
        return "Medium"
    return "Low"


def _build_report(label, percent, description):
    """Format the plain-text prediction report offered for download."""
    return (
        "Prediction Report\n\n"
        f"Disease: {label}\n"
        f"Confidence: {percent:.2f}%\n"
        f"Description: {description}\n"
        "Suggested Next Step: Consult a dermatologist."
    )


image = None
if uploaded_file:
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as err:
        # PIL signals unreadable / non-image uploads with OSError subclasses.
        st.error(f"Could not read the uploaded file as an image: {err}")

if image is not None:
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Only real work runs under the spinner; there is no artificial progress delay.
    with st.spinner("Classifying..."):
        predicted_label, confidence_percent, class_probs = _classify_image(image)

    # No leading spaces in the message: 4-space-indented markdown renders as a code block.
    st.success(f"🩺 **Prediction:** {predicted_label}")
    st.success(f"**Confidence:** ({confidence_percent:.2f}%)")

    # Label bars by class id so each name lines up with its entry in class_probs.
    _plot_confidence(class_probs, [labels[i] for i in range(len(class_probs))])

    # This reflects how certain the model is, not how dangerous the lesion is:
    # a confident benign prediction must not be presented as "High Risk".
    st.write(f"**Confidence Level:** {_confidence_level(confidence_percent)}")

    description = get_education_description(predicted_label)
    st.write(f"**Description:** {description}")

    st.write("**Suggested Next Step:** Please consult a dermatologist for further evaluation of the skin lesion.")

    # Rendered directly: gating it behind another st.button would need two clicks
    # and the download button would vanish again on the next rerun.
    st.download_button(
        label="Download Report",
        data=_build_report(predicted_label, confidence_percent, description),
        file_name="prediction_report.txt",
        mime="text/plain",
        key="download_button",
    )

# Layout: for a side-by-side UI, create columns and render widgets inside them, e.g.
# `left, _, right = st.columns([3, 1, 3])` followed by `with left: ...`.
# A bare st.columns(...) call whose handles are discarded only adds an empty row.


Writing dermAI.py


In [2]:
!pip install streamlit

Collecting streamlit
  Downloading streamlit-1.46.1-py3-none-any.whl.metadata (9.0 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Downloading streamlit-1.46.1-py3-none-any.whl (10.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.1/10.1 MB[0m [31m83.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m104.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl (79 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.1/79.1 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
!pip install streamlit pyngrok

Collecting pyngrok
  Downloading pyngrok-7.2.12-py3-none-any.whl.metadata (9.4 kB)
Downloading pyngrok-7.2.12-py3-none-any.whl (26 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.12


In [4]:
from pyngrok import ngrok
from google.colab import userdata

# Kill existing tunnels if rerunning
ngrok.kill()

# Set the ngrok authentication token from Colab secrets
# Fetch the secret before starting the background process
NGROK_AUTH_TOKEN = userdata.get('NGROK_AUTH_TOKEN')
ngrok.set_auth_token(NGROK_AUTH_TOKEN)

# Run streamlit in background
!streamlit run dermAI.py &> /dev/null &

# Open the tunnel and get the public URL
# Specifying the address as a string including the port
public_url = ngrok.connect('8501')
print(f"Streamlit app is live at: {public_url}")

Streamlit app is live at: NgrokTunnel: "https://45b5d44b1132.ngrok-free.app" -> "http://localhost:8501"
