# This script builds a Medical AI Assistant that combines:

- Chest X-ray classification (Normal vs. Pneumonia, using ResNet18)
- Medical report summarization (using Hugging Face Transformers)
- Guideline-based Q&A with Retrieval-Augmented Generation (RAG)
- An interactive Gradio web app interface

In [None]:
import os, glob, re, shutil, string, torch
import gradio as gr
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline as hf_pipeline
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain.docstore.document import Document

# Paths. The project root can be overridden with the MEDICAL_AI_ROOT
# environment variable; otherwise the original Windows location is used.
# NOTE: a raw string needs only a single backslash; the previous
# r"D:\\MedicalAI-Assistant" literally contained two.
ROOT = os.environ.get('MEDICAL_AI_ROOT', r"D:\MedicalAI-Assistant")
DATA_DIR = os.path.join(ROOT, 'data')
ART_DIR = os.path.join(ROOT, 'artifacts')
XWEIGHTS = os.path.join(ART_DIR, 'xray_model.pth')   # fine-tuned ResNet18 weights
TCAL = os.path.join(ART_DIR, 'xray_temp.pt')         # optional temperature-scaling value
SUM_DIR = os.path.join(ART_DIR, 'summarizer')        # optional local summarizer checkpoint
PDF_PATH = os.path.join(DATA_DIR, 'guidelines.pdf')  # source document for the RAG index
INDEX_DIR = os.path.join(ART_DIR, 'faiss_index')
os.makedirs(ART_DIR, exist_ok=True)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Output order of the X-ray classifier: index 0 -> NORMAL, index 1 -> PNEUMONIA.
CLASS_NAMES = ['NORMAL', 'PNEUMONIA']

#  X-ray model 
from torchvision import models, transforms
from torch import nn

def load_resnet18(weights_path: str, num_classes=2):
    """Build a ResNet18 classifier and load fine-tuned weights from *weights_path*.

    The backbone starts from torchvision's ImageNet weights. The final ``fc``
    layer is replaced by ``Dropout(0.2) -> Linear(in_features, num_classes)``,
    and then the checkpoint is loaded on top.

    Args:
        weights_path: Path to a saved state dict. This may also be a dict
            holding the state dict under ``'state_dict'``, or a pickled
            ``nn.Module``.
        num_classes: Number of output logits.

    Returns:
        The model in eval mode, moved to ``DEVICE``.

    Raises:
        RuntimeError: If the checkpoint provides no weights for the
            classification head. Otherwise the head would stay randomly
            initialised and every prediction would be meaningless.
    """
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    in_feats = model.fc.in_features
    model.fc = nn.Sequential(nn.Dropout(0.2), nn.Linear(in_feats, num_classes))

    sd = torch.load(weights_path, map_location=DEVICE)
    if isinstance(sd, nn.Module):
        sd = sd.state_dict()
    elif isinstance(sd, dict) and 'state_dict' in sd:
        sd = sd['state_dict']

    # Normalise key names from common checkpoint layouts:
    #  - a 'module.' prefix added by DataParallel/DistributedDataParallel;
    #  - a plain Linear head saved as 'fc.weight'/'fc.bias'. Here that layer
    #    lives at 'fc.1.*' because of the Dropout placed in front of it.
    fixed = {}
    for key, value in sd.items():
        if key.startswith('module.'):
            key = key[len('module.'):]
        if key in ('fc.weight', 'fc.bias'):
            key = 'fc.1.' + key.split('.', 1)[1]
        fixed[key] = value

    # strict=False tolerates partial backbone checkpoints, but a missing head
    # must not pass silently.
    result = model.load_state_dict(fixed, strict=False)
    head_missing = [k for k in result.missing_keys if k.startswith('fc.')]
    if head_missing:
        raise RuntimeError(
            f"Checkpoint {weights_path!r} has no weights for the classifier head "
            f"(missing: {head_missing}); refusing to run with a random head."
        )
    model.eval().to(DEVICE)
    return model

# Preprocessing for the ResNet18 classifier: resize to exactly 224x224 (aspect
# ratio is not preserved), convert to a tensor, then normalise each channel
# with the ImageNet statistics that match the pretrained backbone.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_to_resnet = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

def _load_temperature(path, default=1.0):
    """Return the temperature-scaling factor saved at *path*, or *default*.

    Accepts a dict holding the value under ``'T'``, a one-element tensor, or
    a plain number. The result is always a Python float, so dividing logits
    by it works on any device. Non-finite or non-positive values are rejected
    with a warning, because dividing logits by them would flip or destroy the
    probabilities.
    """
    import math
    import warnings

    if not os.path.exists(path):
        return default
    obj = torch.load(path, map_location='cpu')
    if isinstance(obj, dict):
        obj = obj.get('T', default)
    try:
        value = float(obj)  # handles Python numbers and one-element tensors
    except (TypeError, ValueError, RuntimeError) as e:
        warnings.warn(f'Ignoring unreadable temperature in {path!r}: {e}')
        return default
    if not math.isfinite(value) or value <= 0:
        warnings.warn(f'Ignoring invalid temperature {value} in {path!r}')
        return default
    return value


clf = load_resnet18(XWEIGHTS)
# Logits are divided by T before softmax in predict_xray (1.0 = no calibration).
T = _load_temperature(TCAL)

def predict_xray(img):
    """Classify a chest X-ray using horizontal-flip test-time augmentation.

    The image and its left-right mirror are run through ``clf`` as one
    batch. Their logits are averaged, divided by the calibration temperature
    ``T`` and passed through softmax.

    Args:
        img: A PIL image, a file path, or an array accepted by
            ``PIL.Image.fromarray`` (e.g. a uint8 HxW or HxWx3 array).

    Returns:
        Tuple ``(idx, conf, probs)``:
        - idx: index of the most probable class (into CLASS_NAMES);
        - conf: that class's probability;
        - probs: the full probability list.

    Raises:
        ValueError: If *img* is None (nothing uploaded).
    """
    if img is None:
        raise ValueError('No X-ray image provided.')
    if isinstance(img, (str, os.PathLike)):
        # Image.open keeps the file open for lazy loading. Calling convert()
        # inside the context manager loads the pixels, then closes the handle.
        with Image.open(img) as src:
            img = src.convert('RGB')
    else:
        if not isinstance(img, Image.Image):
            img = Image.fromarray(img)
        if img.mode != 'RGB':
            # Normalize() carries 3-channel statistics; grayscale ('L') or RGBA
            # tensors would not match them.
            img = img.convert('RGB')

    x1 = _to_resnet(img).unsqueeze(0)
    x2 = _to_resnet(img.transpose(Image.FLIP_LEFT_RIGHT)).unsqueeze(0)
    batch = torch.cat([x1, x2], 0).to(DEVICE)
    with torch.no_grad():
        logits = clf(batch).mean(0, keepdim=True) / T
        probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    idx = int(probs.argmax())
    return idx, float(probs[idx]), probs.tolist()

# Confidence below which a prediction is reported as inconclusive. With two
# classes the top probability is always >= 0.5, so the uncertain band is
# [0.5, UNCERTAIN_THRESH).
UNCERTAIN_THRESH = 0.55


def decision_text(pred_idx, conf, thresh=None, class_names=None):
    """Turn a classifier prediction into a user-facing result line.

    Args:
        pred_idx: Index of the predicted class.
        conf: Probability of the predicted class.
        thresh: Confidence below which the result is reported as uncertain.
            Defaults to UNCERTAIN_THRESH, read at call time.
        class_names: Labels indexed by *pred_idx*. Defaults to CLASS_NAMES.

    Returns:
        ``'Prediction: <LABEL>'`` when ``conf >= thresh``. Otherwise, an
        explicit uncertainty message that names the leaning class and
        requires clinical correlation.
    """
    if thresh is None:
        thresh = UNCERTAIN_THRESH
    if class_names is None:
        class_names = CLASS_NAMES
    label = class_names[pred_idx]
    if conf < thresh:
        # A low-confidence output is inconclusive. It must never be presented
        # as a normal finding or as not needing clinical follow-up.
        return (f"Uncertain — model leans towards {label} "
                f"(confidence {conf:.0%}). Clinical correlation required.")
    return f"Prediction: {label}"


# Template text for a normal chest X-ray read (PA view). It is not referenced
# anywhere else in the visible code.
NORMAL_REPORT = "\n".join(
    [
        "Chest X-ray (PA):",
        "• Cardiomediastinal silhouette: Within normal limits.",
        "• Lungs: No focal consolidation identified. No pleural effusion.",
        "• Pneumothorax: Not seen.",
        "• Bones/soft tissues: Unremarkable.",
        "",
        "Impression: No acute cardiopulmonary abnormality.",
    ]
)

# Summarizer: prefer a locally fine-tuned checkpoint in SUM_DIR. If that
# directory does not exist, fall back to the public 't5-small' model.
_sum_source = SUM_DIR if os.path.isdir(SUM_DIR) else 't5-small'
sum_tok = AutoTokenizer.from_pretrained(_sum_source)
sum_mod = AutoModelForSeq2SeqLM.from_pretrained(_sum_source)

SUM = hf_pipeline('summarization', model=sum_mod, tokenizer=sum_tok)

#  RAG
# Sentence-transformer model that embeds both the guideline chunks (index
# build) and user queries (retrieval); both sides must use the same model.
_EMB_MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'
EMB = HuggingFaceEmbeddings(model_name=_EMB_MODEL_NAME)

def clean_text(t: str) -> str:
    """Normalise PDF-extracted text for chunking and embedding.

    Each run of characters outside an allow-list becomes a single space.
    Whitespace is then collapsed and the ends are stripped.

    The allow-list keeps the following, which carry meaning in clinical
    guidelines (e.g. "SpO2 < 92%", "age ≥ 65", "38.5°C"):
    - letters and digits;
    - comparison operators (< > = ≤ ≥);
    - + ± ×, percent and degree signs;
    - micro signs;
    - quotes and apostrophes;
    - brackets, slashes, hyphens and en dashes.

    Returns '' for empty or None input.
    """
    if not t:
        return ''
    # Newlines/tabs are outside the allow-list, so this also turns them into spaces.
    t = re.sub(r"""[^A-Za-z0-9 .,;:%°()\[\]/'"+=<>≤≥±×µμ–\-]+""", ' ', t)
    return re.sub(r'\s+', ' ', t).strip()

# Build the FAISS index from the guidelines PDF on first run, then load it.
_CHUNK_CHARS = 800  # fixed-size character chunks (no overlap)

if not os.path.isdir(INDEX_DIR):
    if not os.path.exists(PDF_PATH):
        raise FileNotFoundError(
            f"No FAISS index at {INDEX_DIR!r} and no guidelines PDF at "
            f"{PDF_PATH!r} to build one from."
        )
    docs = PyPDFLoader(PDF_PATH).load()
    chunks = []
    for d in docs:
        c = clean_text(d.page_content)
        for i in range(0, len(c), _CHUNK_CHARS):
            # Keep the loader's per-page metadata so retrieved snippets can
            # be traced back to their source page.
            chunks.append(Document(page_content=c[i:i + _CHUNK_CHARS],
                                   metadata=dict(d.metadata)))
    if not chunks:
        raise ValueError(
            f"No extractable text found in {PDF_PATH!r}; cannot build the FAISS index."
        )
    vs = FAISS.from_documents(chunks, EMB)
    vs.save_local(INDEX_DIR)

# SECURITY NOTE: allow_dangerous_deserialization unpickles the stored
# docstore. This is acceptable only because this script builds INDEX_DIR
# itself; never point INDEX_DIR at files from an untrusted source.
RETRIEVER = FAISS.load_local(
    INDEX_DIR, EMB, allow_dangerous_deserialization=True
).as_retriever(search_kwargs={'k': 3})

# Generator LLM for the RAG chain: FLAN-T5 small with greedy decoding.
llm_tok = AutoTokenizer.from_pretrained('google/flan-t5-small')
llm_mod = AutoModelForSeq2SeqLM.from_pretrained('google/flan-t5-small')
llm_pipe = hf_pipeline(
    'text2text-generation',
    model=llm_mod,
    tokenizer=llm_tok,
    max_new_tokens=160,
    # Greedy decoding. No `temperature` is passed: it only applies when
    # sampling, and transformers warns that it is ignored otherwise.
    do_sample=False,
    repetition_penalty=1.2,
)
LLM = HuggingFacePipeline(pipeline=llm_pipe)

# Prompt for the 'stuff' chain. The retrieved chunks are concatenated into
# {context}, and the user's query fills {question}.
_QA_TEMPLATE = (
    "You are a medical guideline assistant. Answer ONLY from the context.\n"
    "If the answer is not in the context, say: 'I don't know'.\n\n"
    "Context:\n{context}\n\nQuestion: {question}\nAnswer:"
)
PROMPT = PromptTemplate(template=_QA_TEMPLATE, input_variables=['context', 'question'])

# Retrieval QA chain: the top-k guideline chunks from RETRIEVER are stuffed
# into PROMPT and answered by LLM.
_qa_options = dict(
    llm=LLM,
    chain_type='stuff',
    retriever=RETRIEVER,
    chain_type_kwargs={'prompt': PROMPT},
    # Return the retrieved chunks so the UI can show raw snippets as a fallback.
    return_source_documents=True,
)
QA = RetrievalQA.from_chain_type(**_qa_options)

#  Gibberish guard 
def looks_like_gibberish(s: str) -> bool:
    """Heuristically decide whether a generated answer is unusable.

    Returns True for any of:
    - empty/None input, or input shorter than 8 characters;
    - input where fewer than 40% of characters are letters;
    - input where fewer than 90% of characters are in ``string.printable``
      (ASCII-printable).
    """
    if not s:
        return True
    total = len(s)
    if total < 8:
        return True
    alpha_ratio = sum(1 for ch in s if ch.isalpha()) / total
    printable_ratio = sum(1 for ch in s if ch in string.printable) / total
    return alpha_ratio < 0.4 or printable_ratio < 0.9

#  Pipeline helpers (one per UI output)

def _xray_result(xray_image):
    """Return the X-ray result line for the UI ('' when no image was uploaded)."""
    if xray_image is None:
        return ''
    try:
        pred_idx, prob, _ = predict_xray(xray_image)
    except Exception as e:  # surface model/image errors in the UI, like the other sections
        return f'X-ray error: {e}'
    return decision_text(pred_idx, prob)


def _summarize_report(report_text):
    """Return a summary of *report_text* ('' when it is blank)."""
    if not (report_text and report_text.strip()):
        return ''
    try:
        # Use max_new_tokens, not max_length: a max_new_tokens value in the
        # model's generation config takes precedence and max_length would be
        # ignored. truncation=True clips over-long reports to the tokenizer's
        # maximum input length instead of failing.
        out = SUM(report_text, max_new_tokens=120, min_length=20,
                  do_sample=False, truncation=True)
        return out[0]['summary_text']
    except Exception as e:
        return f'Summarization error: {e}'


def _answer_question(question_text):
    """Answer *question_text* with the RAG chain ('' when it is blank).

    If the generated answer looks like gibberish, return up to three raw
    retrieved snippets (first 500 characters each) instead.
    """
    if not (question_text and question_text.strip()):
        return ''
    try:
        res = QA.invoke({'query': question_text})
        answer = res['result'] if isinstance(res, dict) else str(res)
        if not looks_like_gibberish(answer):
            return answer
        docs = ((res.get('source_documents') if isinstance(res, dict) else None)
                or RETRIEVER.invoke(question_text))
        snippets = [d.page_content.strip()[:500] for d in docs[:3] if d.page_content]
        body = "\n\n---\n\n".join(snippets) if snippets else 'No guideline text found.'
        return 'I could not form a clear answer. Relevant guideline snippets:\n\n' + body
    except Exception as e:
        return f'RAG error: {e}'


#  Gradio UI
with gr.Blocks() as demo:
    gr.Markdown('# 🩺 Medical AI Assistant — X-ray + Reports + RAG')
    with gr.Row():
        xray = gr.Image(label='Upload Chest X-ray', type='pil')
        report = gr.Textbox(lines=12, label='Paste medical report / notes')
        question = gr.Textbox(label='Ask a medical guideline question')
        btn = gr.Button('Analyze')

    xray_out = gr.Textbox(label='X-ray result')
    summary_out = gr.Textbox(label='Report summary')
    answer_out = gr.Textbox(label='Guideline answer (RAG)')

    def run_pipeline(xray_image, report_text, question):
        """Run each section only for the inputs the user actually provided.

        Returns ``(xray_result, summary, answer)``. An empty input produces ''
        for its section, and a failing section reports its own error text
        without blocking the other two. ``question`` here is the question
        string passed by Gradio, not the Textbox component.
        """
        return (
            _xray_result(xray_image),
            _summarize_report(report_text),
            _answer_question(question),
        )

    # bind the button inside the Blocks
    btn.click(run_pipeline, inputs=[xray, report, question],
              outputs=[xray_out, summary_out, answer_out])

if __name__ == '__main__':
    demo.launch()


Device set to use cpu
Device set to use cpu
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.


Both `max_new_tokens` (=256) and `max_length`(=120) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)
The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
