In [1]:
!pip install -q transformers datasets peft accelerate bitsandbytes trl sentence-transformers faiss-cpu streamlit opencv-python-headless sacremoses

!pip install -q flash-attn --no-build-isolation

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m34.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m518.9/518.9 kB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m100.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.0/9.0 MB[0m [31m130.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m48.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m123.8 MB/s[0m eta [36m0:00:00[0m00:01[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m83.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for flash-attn 

In [2]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
# Authenticate to the Hugging Face Hub with a token stored in Kaggle Secrets
# (never hardcode it). app.py later downloads meta-llama/Llama-3.2-3B-Instruct
# from the Hub in a separate Streamlit process.
# NOTE(review): that process presumably relies on the token `login` persists to
# the local HF cache — check this first if model loading fails with a 401.
user_secrets = UserSecretsClient()
secret_value = user_secrets.get_secret("HF_token")
login(token=secret_value)

In [3]:
#Using the PubMedQA dataset (optional, I used this previously with BioGPT)

from datasets import load_dataset
# Load PubMedQA's "pqa_labeled" configuration (train split; 1000 rows per the output below).
dataset = load_dataset("pubmed_qa", "pqa_labeled", split="train")

def format_example(example):
    """Turn one PubMedQA record into a single prompt-plus-answer training string.

    Args:
        example: a dataset row (dict-like). Reads ``question``, the first passage of
            ``context["contexts"]`` and ``long_answer`` (falling back to ``answer``).

    Returns:
        ``{"text": ...}`` so that ``Dataset.map`` adds a ``text`` column.

    A missing/``None`` context, an empty passage list, or a ``None`` answer yields an
    empty string in that slot instead of raising or rendering the text "None".
    """
    context = example.get("context") or {}
    passages = context.get("contexts") or [""]
    answer = example.get("long_answer") or example.get("answer") or ""
    return {
        "text": f"You are a medical expert. Provide detailed explanation including causes, symptoms, treatments. "
                f"Question: {example['question']} Context: {passages[0]} "
                f"Answer: {answer}"
    }

# Adds a "text" column built by format_example to every row.
# NOTE(review): formatted_dataset is not referenced by any visible cell (app.py uses
# MEDICAL_KB + an off-the-shelf Llama) — keep only if another cell trains on it.
formatted_dataset = dataset.map(format_example)

README.md: 0.00B [00:00, ?B/s]

pqa_labeled/train-00000-of-00001.parquet:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [5]:
#Connecting skin image folders plus a bit of error handling

import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np

# Kaggle input locations: HAM10000 images (part 1 only) and the lesion masks, which
# SkinDataset expects to be named "<image stem>_segmentation.png".
IMAGE_DIR = "/kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_1"
MASK_DIR = "/kaggle/input/ham10000-lesion-segmentations/HAM10000_segmentations_lesion_tschandl" 

class SkinDataset(Dataset):
    """HAM10000 image/mask pairs for binary lesion segmentation.

    Each ``*.jpg`` in ``image_dir`` is paired with ``<stem>_segmentation.png`` in
    ``mask_dir``. Items are ``(image, mask)`` float tensors:
      * image: ``(3, image_size, image_size)`` in [0, 1], **BGR** channel order
        (OpenCV's default) — the Streamlit app also feeds BGR at inference.
      * mask: ``(1, image_size, image_size)`` scaled by 1/255 (assumes masks are
        stored as 0/255 — TODO confirm for the mask dataset used).

    Args:
        image_dir: directory containing the ``.jpg`` images.
        mask_dir: directory containing the segmentation masks.
        image_size: side length that images and masks are resized to (default 128).
    """

    def __init__(self, image_dir, mask_dir, image_size=128):
        if not os.path.exists(image_dir):
            print(f"ERROR: Path {image_dir} not found!")
            self.images = []
        else:
            # os.listdir order is arbitrary; sort so index i — and therefore the
            # "first N images" training subset — is reproducible across machines.
            self.images = sorted(f for f in os.listdir(image_dir) if f.endswith('.jpg'))

        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_size = image_size

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name.replace('.jpg', '_segmentation.png'))
        size = (self.image_size, self.image_size)

        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None.
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.resize(img, size)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # Best effort: a missing mask is treated as "no lesion".
            # NOTE(review): this silently trains on empty targets if MASK_DIR is wrong.
            mask = np.zeros(size, dtype=np.uint8)
        else:
            # Nearest-neighbour keeps mask values binary; the default bilinear
            # interpolation would blur edges into fractional labels.
            mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)

        img_tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        mask_tensor = torch.from_numpy(mask).unsqueeze(0).float() / 255.0
        return img_tensor, mask_tensor

full_dataset = SkinDataset(IMAGE_DIR, MASK_DIR)

# Upper bound on how many images are used for training.
NUM_TRAIN_IMAGES = 1000

if len(full_dataset) == 0:
    # The training cell needs train_subset; fail here with an actionable message
    # instead of a NameError further down the notebook.
    raise FileNotFoundError(f"No .jpg images found in {IMAGE_DIR}. Check your IMAGE_DIR path.")

num_to_use = min(NUM_TRAIN_IMAGES, len(full_dataset))
train_indices = np.arange(num_to_use)
train_subset = Subset(full_dataset, train_indices)
print(f"Success! Dataset ready with {len(train_subset)} images found in {IMAGE_DIR}")

Success! Dataset ready with 1000 images found in /kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_1


In [6]:
#Creating the architecture for identifying tumors

import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class MedicalUNet(nn.Module):
    """Lightweight encoder-decoder for binary lesion segmentation.

    Encoder: ResNet-18 children up to and including ``layer3`` (256 channels at
    1/16 of the input resolution). Decoder: four stride-2 transposed convolutions
    (256->128->64->32->16 channels) restoring the input resolution, then a 1x1 conv
    to a single logit channel. NOTE: despite the name there are no skip
    connections, so this is not a true U-Net.

    Args:
        pretrained: initialise the encoder with ImageNet weights (default True, the
            same checkpoint the old ``pretrained=True`` flag selected). Pass False
            to skip the download, e.g. when a trained state_dict is loaded anyway.

    ``forward`` maps ``(N, 3, H, W)`` to logits ``(N, 1, H, W)`` when H and W are
    divisible by 16; apply ``sigmoid`` for probabilities.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # `pretrained=` is deprecated since torchvision 0.13; use the weights enum.
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        base = models.resnet18(weights=weights)
        # Drop the last three children (layer4, avgpool, fc).
        self.encoder = nn.Sequential(*list(base.children())[:-3])

        # Attribute names must stay in sync with the app's copy of this class,
        # which restores these weights via load_state_dict.
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)

        self.final = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, x):
        x = self.encoder(x)

        x = F.relu(self.up1(x))
        x = F.relu(self.up2(x))
        x = F.relu(self.up3(x))
        x = F.relu(self.up4(x))

        # Raw logits: pair with BCEWithLogitsLoss in training, sigmoid at inference.
        return self.final(x)

# Build the network and place it on the current CUDA device (same as .cuda()).
model = MedicalUNet().to("cuda")



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 215MB/s]


In [13]:
#Training the U-Net on 1000 skin images

from torch.optim import Adam
from torch.nn import BCEWithLogitsLoss

SEED = 42
NUM_EPOCHS = 5
BATCH_SIZE = 8
# app.py loads /kaggle/working/unet_skin.pth; this relative path lands there only
# when the working directory is /kaggle/working (Kaggle's default) — keep in sync.
CHECKPOINT_PATH = "unet_skin.pth"

# Train wherever the model already lives instead of hard-coding .cuda().
device = next(model.parameters()).device

train_loader = DataLoader(
    train_subset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),  # reproducible shuffle order
)

optimizer = Adam(model.parameters(), lr=1e-4)
# The model outputs raw logits (no final sigmoid), which is what this loss expects.
criterion = BCEWithLogitsLoss()

model.train()
for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(device), masks.to(device)

        preds = model(imgs)
        loss = criterion(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    print(f"Epoch {epoch+1} | Loss: {total_loss/max(len(train_loader), 1):.4f}")

torch.save(model.state_dict(), CHECKPOINT_PATH)
print("Model saved.")

Epoch 1 | Loss: 0.6110
Epoch 2 | Loss: 0.3034
Epoch 3 | Loss: 0.2768
Epoch 4 | Loss: 0.2561
Epoch 5 | Loss: 0.2375
Model saved.


In [18]:
#GPU Verification

import torch
import gc

# Release the notebook's own models so GPU memory is free for the Streamlit app
# process, which loads its own U-Net and Llama model.
if 'model' in globals():
    del model
if 'text_model' in globals():
    del text_model
gc.collect()
torch.cuda.empty_cache()  # no-op when CUDA was never initialised

gpu_available = torch.cuda.is_available()
print(f"Is GPU Available? {gpu_available}")
# get_device_name(0) raises on CPU-only machines, so only query it when a GPU exists.
if gpu_available:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")
else:
    print("GPU Name: none (running on CPU)")

Is GPU Available? True
GPU Name: Tesla P100-PCIE-16GB


In [3]:
%%writefile app.py
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import cv2
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import os
import warnings
import re

# NOTE(review): this also hides deprecation warnings from torch/transformers/streamlit.
warnings.filterwarnings('ignore')

st.set_page_config(page_title="MediScan AI", layout="wide")

# Global theme. The blanket `.stApp p/span/div {color: ... !important}` rule also hits
# the <p>/<div> Streamlit wraps button labels in, so a more specific rule restores
# white label text on the dark-blue button (grey-on-navy was nearly unreadable).
st.markdown("""
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap');

    .stApp {
        background: linear-gradient(135deg, #f8f9fa 0%, #e9ecef 100%);
        font-family: 'Inter', sans-serif;
        color: #64748b;
    }

    /* Target all text elements on light backgrounds */
    .stApp p, .stApp span, .stApp label, .stApp div {
        color: #64748b !important;
    }

    h1, h2, h3 {
        color: #1e3a8a !important;
        font-weight: 600;
        letter-spacing: -0.5px;
    }

    .stTabs [data-baseweb="tab-list"] {
        background: transparent;
        border-bottom: 1px solid rgba(30, 58, 138, 0.15);
        gap: 32px;
        padding-bottom: 8px;
    }

    .stTabs [data-baseweb="tab"] {
        color: #64748b !important;
        font-weight: 500;
        padding: 12px 24px;
        transition: all 0.3s ease;
    }

    .stTabs [aria-selected="true"] {
        color: #1e3a8a !important;
        border-bottom: 3px solid #1e3a8a;
    }

    .stTextInput > div > div > input {
        background: rgba(255, 255, 255, 0.8);
        border: 1px solid rgba(30, 58, 138, 0.2);
        border-radius: 10px;
        color: #1e293b !important;
    }

    .stButton > button {
        background: #1e3a8a !important;
        color: white !important;
        border-radius: 10px;
        border: none;
    }

    /* Button labels are rendered inside <p>/<div>; out-specify the grey text rule above. */
    .stButton > button p, .stButton > button span, .stButton > button div {
        color: white !important;
    }

    .footer {
        position: fixed;
        bottom: 0;
        left: 0;
        right: 0;
        background: rgba(255, 255, 255, 0.9);
        backdrop-filter: blur(10px);
        text-align: center;
        padding: 12px;
        font-size: 13px;
        color: #64748b !important;
        border-top: 1px solid rgba(30, 58, 138, 0.1);
        z-index: 100;
    }

    .summary-box {
        background: rgba(255, 255, 255, 0.5);
        border-radius: 12px;
        padding: 20px;
        margin-top: 20px;
        border: 1px solid rgba(30, 58, 138, 0.1);
    }
    </style>
""", unsafe_allow_html=True)

class MedicalUNet(nn.Module):
    """Inference-side copy of the notebook's lesion segmentation network.

    Submodule names (encoder, up1..up4, final) must match the training-time class
    exactly, because the app restores the weights with ``load_state_dict``.
    The ResNet-18 backbone starts with random weights (``weights=None``); all
    parameters come from the saved checkpoint.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet18(weights=None)
        # Everything except the last three children (layer4, avgpool, fc).
        backbone = list(resnet.children())[:-3]
        self.encoder = nn.Sequential(*backbone)
        # Four x2 transposed-conv upsamplers, halving the channel count each time.
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        # 1x1 projection down to a single logit channel.
        self.final = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, x):
        features = self.encoder(x)
        for upsample in (self.up1, self.up2, self.up3, self.up4):
            features = F.relu(upsample(features))
        return self.final(features)

# Hand-written reference snippets keyed by the condition ids find_condition() returns.
# generate_answer() puts fields of the chosen entry into the LLM prompt, and
# generate_kb_answer() renders it as the fallback / "detailed reference" answer.
# Every entry carries the same five text fields.
# NOTE(review): static educational summaries, not a clinical source — have a
# clinician review any wording change.
MEDICAL_KB = {
    "hypertension": {
        "definition": "Hypertension is persistently elevated blood pressure (≥130/80 mmHg).",
        "causes": "Primary causes include genetics, aging, obesity, high sodium intake, sedentary lifestyle.",
        "symptoms": "Often asymptomatic. When present: severe headaches, fatigue, vision changes.",
        "treatment": "Lifestyle: DASH diet, exercise. Medications: ACE inhibitors, diuretics.",
        "complications": "Heart attack, stroke, heart failure, kidney damage.",
    },
    "melanoma": {
        "definition": "Melanoma is aggressive skin cancer from melanocytes.",
        "causes": "UV radiation (sun, tanning beds). Risk factors: fair skin, family history.",
        "symptoms": "ABCDE: Asymmetry, Border irregularity, Color variation, Diameter >6mm, Evolution.",
        "treatment": "Surgical excision, immunotherapy, BRAF/MEK inhibitors.",
        "complications": "Metastasis to lymph nodes, lungs, liver, brain.",
    },
    "diabetes": {
        "definition": "Diabetes is chronic hyperglycemia due to insulin deficiency or resistance.",
        "causes": "Type 1: autoimmune. Type 2: insulin resistance from obesity, genetics.",
        "symptoms": "Polyuria, polydipsia, polyphagia, weight loss, fatigue.",
        "treatment": "Insulin, metformin, GLP-1 agonists, lifestyle changes.",
        "complications": "Retinopathy, nephropathy, neuropathy, CV disease.",
    },
    "skin_cancer": {
        "definition": "Skin cancer includes BCC, SCC, and melanoma.",
        "causes": "UV radiation, fair skin, family history.",
        "symptoms": "Pearly papules, firm red nodules, or irregular moles.",
        "treatment": "Surgical excision, Mohs surgery, cryotherapy.",
        "complications": "Local tissue destruction or metastasis.",
    },
    "skin_lesion": {
        "definition": "Skin lesion is a focal skin abnormality. Can be benign or malignant.",
        "causes": "Genetic factors, UV damage, or infections.",
        "symptoms": "Changes in color, size, or shape of a skin spot.",
        "treatment": "Observation, biopsy, or surgical removal.",
        "complications": "Risk of malignancy if left unmonitored.",
    }
}

def find_condition(query):
    """Map a free-text question to one MEDICAL_KB entry.

    Conditions are tried in ``keyword_map`` order and the first hit wins. For each
    keyword:
      1. whole-word / whole-phrase match against the normalised query;
      2. fuzzy match (``is_close``) between single query words and single-word
         keywords, only when both are at least ``min_fuzzy_len`` characters long.

    Query and keywords are normalised identically (lower-cased, punctuation
    stripped, whitespace collapsed), so hyphenated keywords such as "skin-cancer"
    are usable and "skin  cancer" still matches.

    Args:
        query: the user's question (may be empty or None).

    Returns:
        ``(condition_key, MEDICAL_KB[condition_key])``, or ``(None, None)``.
    """
    # The fuzzy score is too lenient for short words ("more" scores exactly 0.75
    # against "mole"), so fuzzy matching only considers words of this length or more.
    min_fuzzy_len = 5

    def normalise(text):
        """Lower-case, drop punctuation, collapse runs of whitespace."""
        stripped = re.sub(r'[^a-zA-Z0-9\s]', '', text.lower())
        return " ".join(stripped.split())

    def is_close(word, target, threshold=0.75):
        """Heuristic similarity: 60% positional agreement + 40% character overlap."""
        if word == target:
            return True
        if abs(len(word) - len(target)) > 2:
            return False
        longest = max(len(word), len(target))
        positional = sum(1 for a, b in zip(word, target) if a == b)
        shorter, longer = (word, target) if len(word) < len(target) else (target, word)
        overlap = sum(1 for char in set(shorter) if char in longer)
        score = (positional / longest) * 0.6 + (overlap / longest) * 0.4
        return score >= threshold

    keyword_map = {
        "hypertension": [
            "hypertension", "hypertention", "hipertension", "bloodpressure", "hypertenion", "htn",
            "hypertesion", "hypertention", "hipertenstion", "hipertention", "hpertension", "hyperension",
            "highbp", "hibp", "hi-bp", "hi blood pressure", "blod pressure", "blud pressure", "blood presure",
            "blood presher", "blood presure", "hypertenssion", "hypertensun", "hypertenshun", "hypertensin",
            "hypetension", "hypetention", "hyper-tension", "hibloodpressure", "hiper-tension",
            # Correct spellings were missing, so "high blood pressure" matched nothing.
            "blood pressure", "high blood pressure",
        ],
        "melanoma": [
            "melanoma", "malanoma", "melonoma", "melenoma", "malamona", "melanomma", "melanome",
            "melanoman", "malenoma", "malonoma", "melanuma", "mellanoma", "melanmoma", "melenomma",
            "milanoma", "melanooma", "melaanoma", "melanom", "mellanom", "melanomah", "melanom-a",
            "melanomaa", "melanomwa", "melanomua", "melanomia", "melanomy", "malanomma", "melenona"
        ],
        "diabetes": [
            "diabetes", "dibetes", "diabetis", "diabettis", "diabeties", "diabettes", "diabeetus",
            "diabetus", "diabete", "dyabetes", "dyabetis", "diabeetes", "diabitis", "diabets",
            "diabeets", "diabetese", "dibetis", "dibeties", "dia-betes", "diabeti", "diabetties",
            "diabetees", "diabeteas", "diabeteus", "diabetiss", "diabetez", "diabetiz", "diabeetez",
            "bloodsugar", "blodsugar", "bludsugar", "blod suger", "blood suger",
            # Correct spelling was missing.
            "blood sugar",
        ],
        "skin_cancer": [
            "skincancer", "skincancr", "carcinoma", "basalcell", "squamouscell", "skincansur",
            "skincanser", "skincansar", "skin-cancer", "skin cancr", "skin canser", "skincancer",
            "skn cancer", "skin cancsr", "skincancre", "skin-cancr", "basalcel", "bazalcell",
            "basal-cell", "basal cell", "baselcell", "squamous", "squamus", "sqamous", "squamuscell",
            "squamose", "squamos", "scc", "bcc", "epithelioma",
            # Correct spelling was missing although the UI suggests "skin cancer".
            "skin cancer",
        ],
        "skin_lesion": [
            "lesion", "lession", "lesoin", "mole", "spot", "growth", "nevus", "leision", "leesion",
            "leison", "leson", "lezhun", "leashun", "leashon", "leshin", "leshuon", "skin-lesion",
            "skinlesion", "skinleison", "skin lesion", "moles", "moule", "mowl", "moal", "birthmark",
            "nevi", "neveus", "nevous", "nivas", "niveus", "skingrowth", "skinspot"
        ]
    }

    if not query:
        return None, None

    clean_query = normalise(query)
    fuzzy_words = [w for w in clean_query.split() if len(w) >= min_fuzzy_len]

    for cond, targets in keyword_map.items():
        for target in targets:
            phrase = normalise(target)
            if not phrase:
                continue
            # Whole-word/phrase match; plain substring matching let "molecule" hit "mole".
            if re.search(rf"\b{re.escape(phrase)}\b", clean_query):
                return cond, MEDICAL_KB[cond]
            if (" " not in phrase and len(phrase) >= min_fuzzy_len
                    and any(is_close(w, phrase) for w in fuzzy_words)):
                return cond, MEDICAL_KB[cond]
    return None, None

@st.cache_resource
def load_models(
    checkpoint_path="/kaggle/working/unet_skin.pth",
    llm_id="meta-llama/Llama-3.2-3B-Instruct",
):
    """Load the segmentation network and the 4-bit causal LM, cached by Streamlit.

    Both models need CUDA; on a CPU-only machine all three results are ``None``.
    Each loader is best-effort: a failure is logged (stderr, i.e. the Streamlit
    server log) and that model is returned as ``None`` so the app still runs.

    Args:
        checkpoint_path: state_dict written by the notebook's training cell.
        llm_id: Hugging Face model id of the instruction-tuned LM.

    Returns:
        ``(vision_model, text_model, tokenizer)``; any element may be ``None``.
    """
    import logging
    log = logging.getLogger("mediscan")

    vision_model = None
    text_model, tokenizer = None, None
    if not torch.cuda.is_available():
        log.warning("CUDA not available; vision and text models disabled.")
        return vision_model, text_model, tokenizer

    if os.path.exists(checkpoint_path):
        try:
            vision_model = MedicalUNet().cuda()
            # weights_only: a state_dict needs no pickled code objects.
            state = torch.load(checkpoint_path, map_location="cuda", weights_only=True)
            vision_model.load_state_dict(state)
            vision_model.eval()
        except Exception:
            log.exception("Could not load U-Net checkpoint %s", checkpoint_path)
            vision_model = None
    else:
        # Without trained weights the decoder is random and its masks meaningless,
        # so expose no model rather than an untrained one.
        log.warning("U-Net checkpoint %s not found; image analysis disabled.", checkpoint_path)

    try:
        bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(llm_id)
        # The tokenizer may define no pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        # No trust_remote_code: Llama is implemented natively in transformers, so
        # there is no reason to execute code shipped in the model repository.
        text_model = AutoModelForCausalLM.from_pretrained(
            llm_id, quantization_config=bnb_config, device_map="auto"
        )
    except Exception:
        log.exception("Could not load text model %s", llm_id)
        text_model, tokenizer = None, None

    return vision_model, text_model, tokenizer

vision_model, text_model, tokenizer = load_models()

def generate_answer(query, info, max_new_tokens=100, temperature=0.7):
    """Ask the LLM to answer ``query`` grounded in one MEDICAL_KB entry.

    Args:
        query: the user's question.
        info: the MEDICAL_KB entry chosen by find_condition (dict of text fields).
        max_new_tokens: generation budget for the answer.
        temperature: sampling temperature. Sampling is enabled explicitly, because
            ``generate`` ignores ``temperature`` under its default greedy decoding.

    Returns:
        The generated answer text (prompt excluded), or ``None`` when the model,
        the tokenizer or ``info`` is unavailable.
    """
    if text_model is None or tokenizer is None or not info:
        return None

    # Pass every reference field (not only definition + symptoms) so questions
    # about causes or treatment are grounded as well.
    fields = ("definition", "causes", "symptoms", "treatment", "complications")
    reference = " ".join(str(info[key]) for key in fields if info.get(key))
    prompt = f"### Medical Question\n{query}\n\n### Relevant Medical Information\n{reference}\n\n### Answer\n"

    if getattr(tokenizer, "chat_template", None):
        # Instruct checkpoints are trained on their chat format; use it when present.
        messages = [
            {"role": "system", "content": "You are a careful medical information assistant. Answer concisely using the reference information provided. This is for educational purposes only."},
            {"role": "user", "content": prompt},
        ]
        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_dict=True, return_tensors="pt"
        )
    else:
        inputs = tokenizer(prompt, return_tensors="pt")
    inputs = inputs.to(text_model.device)

    with torch.no_grad():
        outputs = text_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens rather than splitting the echoed prompt.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

def generate_kb_answer(query, info):
    """Render a MEDICAL_KB entry as Markdown, one bold-labelled paragraph per field.

    Args:
        query: the user's question (unused; kept for parity with generate_answer).
        info: a MEDICAL_KB entry, or None.

    Returns:
        Markdown text, or ``None`` when ``info`` is falsy. Fields appear in a fixed
        order (definition, causes, symptoms, treatment, complications); absent or
        empty fields are skipped.
    """
    if not info:
        return None
    sections = (
        ("Definition", "definition"),
        ("Causes", "causes"),
        ("Symptoms", "symptoms"),
        ("Treatment", "treatment"),
        ("Complications", "complications"),
    )
    return "\n\n".join(f"**{label}:** {info[key]}" for label, key in sections if info.get(key))

# Page header (the <h3> is pulled up under the title with a negative margin).
st.title("MediScan AI")
st.markdown("<h3 style='margin-top:-12px;'>Medical Information Assistant</h3>", unsafe_allow_html=True)

# Top-level navigation: Q&A, image segmentation, and an about page.
tab1, tab2, tab3 = st.tabs(["Medical Q&A", "Image Analysis", "About"])

with tab1:
    st.markdown("### Ask Medical Questions")

    col1, col2 = st.columns([3, 1])
    with col1:
        query = st.text_input(
            "Question:",
            placeholder="e.g., What are the symptoms of melanoma?",
            label_visibility="collapsed"
        )
    with col2:
        # Pressing Enter in the text box already submits; "Ask" additionally forces
        # a fresh answer for the same question.
        btn = st.button("Ask", type="primary", use_container_width=True)

    with st.expander("Examples"):
        st.markdown("""
        - What are symptoms of hypertension?
        - How is diabetes treated?
        - What causes melanoma?
        - When to see doctor for skin lesion?
        """)

    if query:
        # Streamlit re-runs the whole script on *any* widget interaction (e.g. an
        # upload in the Image Analysis tab). Keep the last result per session so the
        # LLM is not re-invoked for a question that was already answered.
        cached = st.session_state.get("last_qa")
        if btn or cached is None or cached[0] != query:
            with st.spinner("Processing..."):
                _, info = find_condition(query)
                ai_ans = generate_answer(query, info) if info else None
                kb_ans = generate_kb_answer(query, info)
            st.session_state["last_qa"] = (query, info, ai_ans, kb_ans)
        _, info, ai_ans, kb_ans = st.session_state["last_qa"]

        if info:
            st.markdown("---")

            # Each st.markdown call renders its own element, so a separate opening
            # and closing <div> never wrapped the answer; render it directly.
            if ai_ans:
                st.markdown(f"**AI Response:**\n\n{ai_ans}")
                with st.expander("View Detailed Reference"):
                    st.markdown(kb_ans)
            else:
                st.markdown(f"**Medical Information:**\n\n{kb_ans}")

            st.markdown("---")
            st.success("Educational information only. Consult healthcare professional.")
        else:
            st.warning("No info available. Try: hypertension, melanoma, diabetes, skin cancer, lesions.")

with tab2:
    st.markdown("### Lesion Segmentation")
    file = st.file_uploader("Upload skin image", type=['jpg', 'jpeg', 'png'], label_visibility="collapsed")

    if file and vision_model is None:
        # Previously an upload silently did nothing when the model was unavailable.
        st.warning("Segmentation model is not loaded (it needs a GPU and the trained unet_skin.pth checkpoint).")
    elif file:
        try:
            # getvalue() returns the whole upload regardless of stream position,
            # unlike read(), which returns b"" once the buffer has been consumed.
            bytes_img = np.frombuffer(file.getvalue(), dtype=np.uint8)
            img = cv2.imdecode(bytes_img, cv2.IMREAD_COLOR)
            if img is None:
                raise ValueError("the file could not be decoded as an image")
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # The notebook trained on BGR frames from cv2.imread, so feed BGR here too.
            device = next(vision_model.parameters()).device
            inp = cv2.resize(img, (128, 128))
            tensor = torch.from_numpy(inp).permute(2, 0, 1).float().unsqueeze(0).to(device) / 255.0

            with torch.no_grad():
                pred = vision_model(tensor)
                mask = (torch.sigmoid(pred).cpu().numpy()[0][0] > 0.5).astype(np.uint8) * 255

            percentage = np.count_nonzero(mask) / mask.size * 100

            col1, col2 = st.columns(2)
            with col1:
                st.image(img_rgb, use_container_width=True, caption="Original Image")
            with col2:
                st.image(mask, use_container_width=True, caption="Segment Detected")

            # Don't claim a region was found when the mask is empty.
            if percentage > 0:
                meaning = "The AI has identified a distinct region of interest based on pixel contrast and texture. A higher percentage may indicate a larger surface involvement."
            else:
                meaning = "The model did not mark any pixels as lesion in this image."

            st.markdown(f"""
                <div class="summary-box">
                    <h4>Analysis Summary</h4>
                    <p><b>Lesion Coverage:</b> {percentage:.2f}% of the analyzed area.</p>
                    <p><b>What this means:</b> {meaning}</p>
                    <p style="color:#b91c1c !important; font-weight:600;">Disclaimer: This is an automated segmentation. Please consult a dermatologist for a clinical diagnosis.</p>
                </div>
            """, unsafe_allow_html=True)

        except Exception as e:
            st.error(f"Error: {e}")

# Static "About" page.
# NOTE(review): the segmentation model has no skip connections, so "U-Net" here is
# a loose description.
with tab3:
    st.markdown("### About MediScan AI")
    st.markdown("""
    Educational AI combining Llama 3.2 3B Instruct with U-Net vision model.
    
    **Technology:**
    - Llama 3.2 3B Instruct
    - U-Net + ResNet18
    - HAM10000 dataset
    
    **Limitations:**
    - Research prototype. Not clinically validated.
    """)

# Disclaimer pinned to the bottom of every page (styled by the .footer CSS rule).
st.markdown("""
    <div class="footer">
        This is not medical advice. Consult a specialist for proper consultation and diagnosis.
    </div>
""", unsafe_allow_html=True)

Overwriting app.py


In [4]:
#Used ngrok

!pip install pyngrok --quiet

from kaggle_secrets import UserSecretsClient
from pyngrok import ngrok

user_secrets = UserSecretsClient()
ngrok_token = user_secrets.get_secret("NGROK_AUTH_TOKEN")

!ngrok authtoken {ngrok_token}

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml                                


In [5]:
import time, os
from pyngrok import ngrok

!pkill -f streamlit

print("Starting Streamlit...")
os.system("nohup streamlit run app.py --server.port 8501 --server.headless true > logs.txt 2>&1 &")

time.sleep(30) 

print("Creating ngrok tunnel...")
public_url = ngrok.connect(8501, "http")
print("Interactive Streamlit app is live at:", public_url)
print("Open the URL above (wait 10-30s if needed). Widgets should now be clickable!")

Starting Streamlit...
Creating ngrok tunnel...
Interactive Streamlit app is live at: NgrokTunnel: "https://simon-nontraveling-gerundively.ngrok-free.dev" -> "http://localhost:8501"
Open the URL above (wait 10-30s if needed). Widgets should now be clickable!
