<a href="https://colab.research.google.com/github/katakamnikki07-spec/HositpalBot/blob/main/Submissionmain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install torch pandas scikit-learn sentencepiece gradio rouge-score sacrebleu faiss-cpu sentence-transformers --quiet

import os, re, time, random, json, math
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import sentencepiece as spm
from rouge_score import rouge_scorer
import sacrebleu
import gradio as gr


from google.colab import files
uploaded = files.upload()
CSV_PATH = list(uploaded.keys())[0]
print("Loaded:", CSV_PATH)

# ---------------------------
# 1) Load & clean
# ---------------------------
def clean_text(s: str) -> str:
    return re.sub(r"\s+", " ", str(s)).strip()

df = pd.read_csv(CSV_PATH).dropna().drop_duplicates().reset_index(drop=True)
assert {"question","answer"}.issubset(set(df.columns)), "CSV must have 'question' and 'answer' columns."
df["question"] = df["question"].map(clean_text)
df["answer"]   = df["answer"].map(clean_text)

train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df,   test_df = train_test_split(temp_df, test_size=0.5, random_state=42)
print("Train/Val/Test sizes:", len(train_df), len(val_df), len(test_df))

Saving mle_screening_dataset.csv to mle_screening_dataset.csv
Loaded: mle_screening_dataset.csv
Train/Val/Test sizes: 13082 1635 1636


In [6]:
# ---------------------------
# 2) SentencePiece tokenizer
# ---------------------------
with open("spm_input.txt", "w", encoding="utf-8") as f:
    for t in pd.concat([train_df["question"], train_df["answer"]]):
        f.write(t + "\n")

VOCAB_SIZE = 8000
spm.SentencePieceTrainer.train(
    input="spm_input.txt", model_prefix="spm_medical",
    vocab_size=VOCAB_SIZE, model_type="bpe"
)
sp = spm.SentencePieceProcessor(model_file="spm_medical.model")

PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2
SPECIALS = ["<pad>","<bos>","<eos>"]
itos = SPECIALS + [sp.id_to_piece(i) for i in range(sp.vocab_size())]
stoi = {w:i for i,w in enumerate(itos)}

def encode_sp(text, max_len, add_bos=False, add_eos=False):
    ids = sp.encode(text, out_type=int)
    if add_bos: ids = [BOS_IDX] + ids
    if add_eos: ids = ids + [EOS_IDX]
    return ids[:max_len]

def pad_list(ids, max_len):
    return ids + [PAD_IDX]*(max_len - len(ids))

In [7]:
# ---------------------------
# 3) Dataset/Dataloader
# ---------------------------
MAX_SRC, MAX_TGT = 96, 128

class QADataset(Dataset):
    def __init__(self, frame):
        self.qs = frame["question"].tolist()
        self.as_ = frame["answer"].tolist()
    def __len__(self): return len(self.qs)
    def __getitem__(self, i):
        src = encode_sp(self.qs[i], MAX_SRC, add_bos=False, add_eos=False)
        tgt_core = encode_sp(self.as_[i], MAX_TGT-2, add_bos=False, add_eos=False)
        tgt_in  = [BOS_IDX] + tgt_core
        tgt_out = tgt_core + [EOS_IDX]
        return torch.tensor(src), torch.tensor(tgt_in), torch.tensor(tgt_out)

def collate(batch):
    srcs, tins, touts = zip(*batch)
    ms = max(len(x) for x in srcs)
    mt = max(len(y) for y in tins)
    src_pad  = [pad_list(x.tolist(), ms) for x in srcs]
    tin_pad  = [pad_list(y.tolist(), mt) for y in tins]
    tout_pad = [pad_list(z.tolist(), mt) for z in touts]
    return torch.tensor(src_pad), torch.tensor(tin_pad), torch.tensor(tout_pad)

BATCH=64
train_loader = DataLoader(QADataset(train_df), batch_size=BATCH, shuffle=True,  collate_fn=collate)
val_loader   = DataLoader(QADataset(val_df),   batch_size=BATCH, shuffle=False, collate_fn=collate)

In [8]:
# ---------------------------
# 4) Seq2Seq Transformer
# ---------------------------
class Seq2SeqTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, nlayers=3, dff=512, dropout=0.1):
        super().__init__()
        self.src_embed = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)
        self.pos_src   = nn.Embedding(2048, d_model)
        self.pos_tgt   = nn.Embedding(2048, d_model)
        self.tf = nn.Transformer(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=nlayers, num_decoder_layers=nlayers,
            dim_feedforward=dff, dropout=dropout, batch_first=True
        )
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt_in):
        bs, sl = src.shape
        bs, tl = tgt_in.shape
        ps = torch.arange(sl, device=src.device).unsqueeze(0).expand(bs, sl)
        pt = torch.arange(tl, device=tgt_in.device).unsqueeze(0).expand(bs, tl)
        se = self.src_embed(src) + self.pos_src(ps)
        te = self.tgt_embed(tgt_in) + self.pos_tgt(pt)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tl).to(src.device)
        out = self.tf(
            se, te, tgt_mask=tgt_mask,
            src_key_padding_mask=(src==PAD_IDX),
            tgt_key_padding_mask=(tgt_in==PAD_IDX),
            memory_key_padding_mask=(src==PAD_IDX)
        )
        return self.proj(out)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = sp.vocab_size() + len(SPECIALS)
model = Seq2SeqTransformer(vocab_size).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)


In [9]:
# ---------------------------
# 5) Train few epochs
# ---------------------------
EPOCHS = 60
for ep in range(1, EPOCHS+1):
    model.train(); total=0
    for src, tin, tout in train_loader:
        src, tin, tout = src.to(device), tin.to(device), tout.to(device)
        opt.zero_grad()
        logits = model(src, tin)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), tout.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        total += loss.item()
    tr = total/len(train_loader)

    model.eval(); vtotal=0
    with torch.no_grad():
        for src, tin, tout in val_loader:
            src, tin, tout = src.to(device), tin.to(device), tout.to(device)
            logits = model(src, tin)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), tout.reshape(-1))
            vtotal += loss.item()
    va = vtotal/len(val_loader)
    print(f"Epoch {ep}: train={tr:.3f} val={va:.3f}")

# Save checkpoint
CKPT_PATH = "medical_qa.pt"
torch.save({
    "state_dict": model.state_dict(),
    "config": {"d_model":256,"nhead":8,"nlayers":3,"dff":512,"dropout":0.1,"vocab_size":vocab_size},
    "spm_model": "spm_medical.model"
}, CKPT_PATH)
print("Saved checkpoint:", CKPT_PATH)



Epoch 1: train=6.245 val=5.148
Epoch 2: train=4.902 val=4.557
Epoch 3: train=4.444 val=4.229
Epoch 4: train=4.142 val=3.994
Epoch 5: train=3.904 val=3.814
Epoch 6: train=3.703 val=3.658
Epoch 7: train=3.536 val=3.532
Epoch 8: train=3.386 val=3.426
Epoch 9: train=3.251 val=3.337
Epoch 10: train=3.130 val=3.253
Epoch 11: train=3.018 val=3.184
Epoch 12: train=2.919 val=3.121
Epoch 13: train=2.825 val=3.073
Epoch 14: train=2.738 val=3.027
Epoch 15: train=2.658 val=2.986
Epoch 16: train=2.586 val=2.948
Epoch 17: train=2.520 val=2.921
Epoch 18: train=2.456 val=2.898
Epoch 19: train=2.398 val=2.885
Epoch 20: train=2.343 val=2.868
Epoch 21: train=2.293 val=2.853
Epoch 22: train=2.242 val=2.847
Epoch 23: train=2.199 val=2.837
Epoch 24: train=2.155 val=2.837
Epoch 25: train=2.115 val=2.829
Epoch 26: train=2.077 val=2.825
Epoch 27: train=2.040 val=2.824
Epoch 28: train=2.004 val=2.829
Epoch 29: train=1.971 val=2.824
Epoch 30: train=1.938 val=2.829
Epoch 31: train=1.906 val=2.833
Epoch 32: train=1

In [10]:
# ---------------------------
# 6) Decoding: Beam search & helpers
# ---------------------------
def ids_to_text(ids):

    clean = [i for i in ids if i not in (PAD_IDX, BOS_IDX, EOS_IDX)]
    return sp.decode(clean)

@torch.no_grad()
def greedy_answer(question, max_len=128):
    model.eval()
    src_ids = encode_sp(question, MAX_SRC, add_bos=False, add_eos=False)
    src = torch.tensor([src_ids], device=device)
    tgt = torch.tensor([[BOS_IDX]], device=device)
    for _ in range(max_len-1):
        logits = model(src, tgt)
        nxt = logits[:, -1, :].argmax(-1, keepdim=True)
        tgt = torch.cat([tgt, nxt], dim=1)
        if nxt.item() == EOS_IDX: break
    return ids_to_text(tgt[0].tolist())

@torch.no_grad()
def beam_search(question, beam=3, max_len=128, context_prefix=None):
    model.eval()
    if context_prefix:
        # prepend retrieved context to the source (RAG)
        rag_src = f"{context_prefix}\nQuestion: {question}"
        src_ids = encode_sp(rag_src, MAX_SRC, add_bos=False, add_eos=False)
    else:
        src_ids = encode_sp(question, MAX_SRC, add_bos=False, add_eos=False)

    src = torch.tensor([src_ids], device=device)
    beams = [(torch.tensor([[BOS_IDX]], device=device), 0.0)]
    for _ in range(max_len-1):
        new_beams=[]
        for seq,score in beams:
            if seq[0,-1]==EOS_IDX:
                new_beams.append((seq,score)); continue
            logits = model(src, seq)
            probs = torch.log_softmax(logits[:,-1,:], dim=-1)
            topk = torch.topk(probs, beam)
            for idx, s in zip(topk.indices[0], topk.values[0]):
                new_beams.append((torch.cat([seq, idx.view(1,1)], dim=1), score + s.item()))
        beams = sorted(new_beams, key=lambda x: x[1], reverse=True)[:beam]
    best_ids = beams[0][0][0].tolist()
    return ids_to_text(best_ids)

In [11]:
# ---------------------------
# 7) Evaluation (ROUGE-L & BLEU)
# ---------------------------
def evaluate(sample_n=150, use_beam=True, use_rag=False, k=3):
    subset = test_df.sample(min(sample_n, len(test_df)), random_state=0)
    rouge_scores, bleu_scores = [], []
    for _, row in subset.iterrows():
        q, ref = row["question"], row["answer"]
        if use_rag:
            ctx, _ = retrieve_context(q, top_k=k)
            ctx_text = " ".join(ctx)
            hyp = beam_search(q, beam=3, context_prefix=ctx_text) if use_beam else greedy_answer(q)  # greedy ignores ctx
        else:
            hyp = beam_search(q, beam=3) if use_beam else greedy_answer(q)
        r = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True).score(ref, hyp)["rougeL"].fmeasure
        b = sacrebleu.corpus_bleu([hyp], [[ref]]).score
        rouge_scores.append(r); bleu_scores.append(b)
    print(f"Eval (beam={use_beam}, RAG={use_rag}, k={k}) -> ROUGE-L: {np.mean(rouge_scores):.4f}  BLEU: {np.mean(bleu_scores):.2f}")

In [12]:
# ---------------------------
# 8) Build Retriever (Embeddings + FAISS, with TF-IDF fallback)
# ---------------------------
USE_EMBED = True
EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

retrieval_corpus = []
for i, row in df.iterrows():

    doc = f"{row['question']} || {row['answer']}"
    retrieval_corpus.append(doc)

retriever_info = {"mode": None}

try:
    from sentence_transformers import SentenceTransformer
    import faiss

    emb_model = SentenceTransformer(EMB_MODEL_NAME)
    emb = emb_model.encode(retrieval_corpus, batch_size=64, show_progress_bar=True, normalize_embeddings=True)
    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(emb.astype("float32"))

    retriever_info["mode"] = "embeddings"
    retriever_info["emb_model_name"] = EMB_MODEL_NAME
    retriever_info["faiss_index_size"] = index.ntotal
    print("Retriever: sentence-transformers + FAISS (IP)")

    def retrieve_context(query, top_k=3):
        q_emb = emb_model.encode([query], normalize_embeddings=True)
        D, I = index.search(q_emb.astype("float32"), top_k)
        hits = [retrieval_corpus[i] for i in I[0]]
        return hits, D[0].tolist()

except Exception as e:
    print("Embedding retriever unavailable, falling back to TF-IDF. Reason:", e)
    tfidf = TfidfVectorizer(ngram_range=(1,2), min_df=2).fit(retrieval_corpus)
    mat = tfidf.transform(retrieval_corpus)
    retriever_info["mode"] = "tfidf"
    retriever_info["vocab_size"] = len(tfidf.vocabulary_)
    print("Retriever: TF-IDF cosine")

    def retrieve_context(query, top_k=3):
        qv = tfidf.transform([query])
        sims = cosine_similarity(qv, mat)[0]
        idx = np.argsort(-sims)[:top_k]
        return [retrieval_corpus[i] for i in idx], sims[idx].tolist()

# Quick sanity check evals (optional, can comment to save time)
evaluate(sample_n=60, use_beam=True,  use_rag=False)
evaluate(sample_n=60, use_beam=True,  use_rag=True,  k=3)


Batches:   0%|          | 0/256 [00:00<?, ?it/s]

Retriever: sentence-transformers + FAISS (IP)
Eval (beam=True, RAG=False, k=3) -> ROUGE-L: 0.3145  BLEU: 13.28
Eval (beam=True, RAG=True, k=3) -> ROUGE-L: 0.1993  BLEU: 6.31


In [15]:
# ---------------------------
# 9) Gradio App (Chat + Mode Switch + Sources)
# ---------------------------
APP_TITLE = " Medical QA Bot — Transformer + RAG"
APP_DESC  = ("Ask general medical questions. This system is trained on dataset— "
)

def answer_api(message, mode, top_k, show_sources):
    """
    mode: "Generator only" | "RAG (retrieve + generate)"
    top_k: int for RAG retrieval
    """
    if mode == "RAG (retrieve + generate)":
        ctx, scores = retrieve_context(message, top_k=int(top_k))
        ctx_text = " ".join(ctx)
        ans = beam_search(message, beam=3, context_prefix=ctx_text)
        if show_sources:
            srcs = "\n\nSources:\n" + "\n".join([f"- {c}" for c in ctx])
            return ans + srcs
        return ans
    else:
        return beam_search(message, beam=3)

with gr.Blocks(title=APP_TITLE) as demo:
    gr.Markdown(f"# {APP_TITLE}")
    gr.Markdown(APP_DESC)

    with gr.Row():
        mode = gr.Radio(choices=["Generator only", "RAG (retrieve + generate)"],
                        value="RAG (retrieve + generate)", label="Mode")
        top_k = gr.Slider(1, 10, value=3, step=1, label="RAG: top-k passages")
        show_sources = gr.Checkbox(value=True, label="Show retrieved sources")

    chat = gr.Chatbot(height=340)
    msg = gr.Textbox(label="Your medical question")
    clear = gr.Button("Clear")

    def user_submit(user_message, history, mode, top_k, show_sources):
        if not user_message or not user_message.strip():
            return gr.update(value=""), history
        history = history + [[user_message, None]]
        # compute answer
        ans = answer_api(user_message, mode, top_k, show_sources)
        history[-1][1] = ans
        return "", history

    msg.submit(user_submit, [msg, chat, mode, top_k, show_sources], [msg, chat])
    clear.click(lambda: ([], ""), outputs=[chat, msg])

demo.launch()

  chat = gr.Chatbot(height=340)


It looks like you are running Gradio on a hosted Jupyter notebook, which requires `share=True`. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://45c7ccc6fe78ccdd63.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [14]:
# ---------------------------
# 10) Reload checkpoint utility
# ---------------------------
def load_checkpoint(path=CKPT_PATH):
    ck = torch.load(path, map_location=device)
    cfg = ck["config"]
    model = Seq2SeqTransformer(
        vocab_size=cfg["vocab_size"], d_model=cfg["d_model"],
        nhead=cfg["nhead"], nlayers=cfg["nlayers"], dff=cfg["dff"],
        dropout=cfg["dropout"]
    ).to(device)
    model.load_state_dict(ck["state_dict"])
    return model

print("\nDone. App is running above. You can stop it and re-run demo.launch() after training more, if needed.")


Done. App is running above. You can stop it and re-run demo.launch() after training more, if needed.
