In [1]:
!pip install streamlit sentence-transformers datasets faiss-cpu langgraph langchain-core transformers accelerate bitsandbytes python-dotenv

Collecting streamlit
  Downloading streamlit-1.51.0-py3-none-any.whl.metadata (9.5 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting langgraph
  Downloading langgraph-1.0.2-py3-none-any.whl.metadata (7.4 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Collecting langgraph-checkpoint<4.0.0,>=2.1.0 (from langgraph)
  Downloading langgraph_checkpoint-3.0.1-py3-none-any.whl.metadata (4.7 kB)
Collecting langgraph-prebuilt<1.1.0,>=1.0.2 (from langgraph)
  Downloading langgraph_prebuilt-1.0.2-py3-none-any.whl.metadata (5.0 kB)
Collecting langgraph-sdk<0.3.0,>=0.2.2 (from langgraph)
  Downloading langgraph_sdk-0.2.9-py3-none-any.whl.metadata (1.5 kB)
Collecting ormsgpack>=1.12.0 (from langgraph-checkpoint<4.0.0,>=2.1.0->lang

In [None]:
!npm install -g localtunnel

[1G[0K‚†ô[1G[0K‚†π[1G[0K‚†∏[1G[0K‚†º[1G[0K‚†¥[1G[0K‚†¶[1G[0K‚†ß[1G[0K‚†á[1G[0K‚†è[1G[0K‚†ã[1G[0K‚†ô[1G[0K‚†π[1G[0K‚†∏[1G[0K‚†º[1G[0K‚†¥[1G[0K‚†¶[1G[0K‚†ß[1G[0K‚†á[1G[0K‚†è[1G[0K‚†ã[1G[0K‚†ô[1G[0K‚†π[1G[0K‚†∏[1G[0K‚†º[1G[0K‚†¥[1G[0K‚†¶[1G[0K‚†ß[1G[0K‚†á[1G[0K‚†è[1G[0K‚†ã[1G[0K‚†ô[1G[0K‚†π[1G[0K‚†∏[1G[0K‚†º[1G[0K‚†¥[1G[0K‚†¶[1G[0K‚†ß[1G[0K‚†á[1G[0K‚†è[1G[0K‚†ã[1G[0K‚†ô[1G[0K‚†π[1G[0K‚†∏[1G[0K‚†º[1G[0K
added 22 packages in 5s
[1G[0K‚†¥[1G[0K
[1G[0K‚†¥[1G[0K3 packages are looking for funding
[1G[0K‚†¥[1G[0K  run `npm fund` for details
[1G[0K‚†¥[1G[0K

In [None]:
%%writefile app.py
import streamlit as st
import json
import faiss
import numpy as np
import torch
from typing import TypedDict, Annotated, Sequence
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from dotenv import load_dotenv
import os

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import operator

# Page configuration
st.set_page_config(
    page_title="MedLang - Women's Health Assistant",
    page_icon="ü§∞",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better UI
st.markdown("""
    <style>
    .stApp {
        background-color: #f5f7fa;
    }
    .chat-message {
        padding: 1.5rem;
        border-radius: 0.5rem;
        margin-bottom: 1rem;
        display: flex;
        flex-direction: column;
    }
    .user-message {
        background-color: #e3f2fd;
        border-left: 4px solid #2196f3;
    }
    .assistant-message {
        background-color: #f1f8e9;
        border-left: 4px solid #8bc34a;
    }
    .reasoning-box {
        background-color: #fff3e0;
        border-left: 4px solid #ff9800;
        padding: 1rem;
        border-radius: 0.5rem;
        margin-bottom: 0.5rem;
        font-size: 0.9rem;
    }
    .model-badge {
        display: inline-block;
        padding: 0.3rem 0.8rem;
        border-radius: 1rem;
        font-size: 0.85rem;
        font-weight: bold;
        margin-bottom: 0.5rem;
        background-color: #ffebee;
        color: #c62828;
    }
    .stButton>button {
        width: 100%;
        background-color: #8bc34a;
        color: white;
    }
    .main-header {
        text-align: center;
        color: #2c3e50;
        padding: 2rem 0;
    }
    .context-box {
        background-color: #e1f5fe;
        border-left: 4px solid #0277bd;
        padding: 0.8rem;
        border-radius: 0.5rem;
        margin-bottom: 0.5rem;
        font-size: 0.85rem;
    }
    .stats-card {
        background-color: white;
        padding: 1rem;
        border-radius: 0.5rem;
        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
        margin-bottom: 1rem;
    }
    </style>
""", unsafe_allow_html=True)


# --- Load environment variables ---
@st.cache_resource
def load_environment():
    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        st.error("‚ö†Ô∏è HF_TOKEN not found in .env file! Please add your HuggingFace access token.")
        st.info("Get your token from: https://huggingface.co/settings/tokens")
        st.stop()
    return hf_token

# --- Initialize embedding model ---
@st.cache_resource
def initialize_embedder():
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


# --- Load Menstrual-LLaMA (Quantized Direct Load) ---
@st.cache_resource
def load_menstrual_llama(hf_token):
    """
    Load the Menstrual-LLaMA-8B model directly with 4-bit quantization
    for efficient use on Colab GPU.
    """
    try:
        # Define 4-bit quantization configuration
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

        with st.spinner("üî¥ Loading Menstrual-LLaMA-8B model with 4-bit quantization..."):
            model_path = "proadhikary/Menstrual-LLaMA-8B"

            # 1. Load the tokenizer from the model path
            tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)

            # 2. Load the model directly, applying 4-bit config
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                token=hf_token,
                quantization_config=bnb_config, # Apply BITSANDBYTES CONFIG
                device_map="auto",
            )

            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            model.eval()

        # st.success("‚úÖ Menstrual-LLaMA loaded successfully !") -----> commented this
        return model, tokenizer

    except Exception as e:
        st.error(f"‚ö†Ô∏è Could not load Menstrual-LLaMA: {str(e)}")
        st.info("This loading method requires your HF_TOKEN to have access to 'proadhikary/Menstrual-LLaMA-8B'. Please check your token and model permissions.")
        st.stop()


# --- Load dataset and build FAISS index ---
@st.cache_resource
def load_dataset_and_index(data_file, _embedder):
    try:
        qa_pairs = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                obj = json.loads(line)
                qa_pairs.append({"question": obj["question"], "answer": obj["answer"]})

        dataset = Dataset.from_list(qa_pairs)
        question_embeddings = _embedder.encode(dataset["question"], convert_to_numpy=True)

        dim = question_embeddings.shape[1]
        index = faiss.IndexFlatL2(dim)
        index.add(question_embeddings)

        return dataset, index
    except FileNotFoundError:
        st.error(f"‚ö†Ô∏è Dataset file '{data_file}' not found! Please ensure it's in the same directory.")
        st.stop()
    except Exception as e:
        st.error(f"‚ö†Ô∏è Error loading dataset: {str(e)}")
        st.stop()


# --- State Definition ---
class GraphState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    question: str
    retrieved_context: list
    reasoning: str
    answer: str


# --- Helper Functions ---
# Removed language detection function


def format_chat_history(messages, max_exchanges=3):
    """Format recent chat history for context"""
    if not messages:
        return ""

    history_str = "Previous conversation:\n"
    for msg in messages[-(max_exchanges*2):]:
        if isinstance(msg, HumanMessage):
            history_str += f"User: {msg.content}\n"
        elif isinstance(msg, AIMessage):
            history_str += f"Assistant: {msg.content}\n"
    return history_str + "\n"


# --- Node Functions ---
def retrieve_context(state: GraphState, dataset, index, embedder) -> GraphState:
    """Retrieve top 2 most relevant Q&A pairs from pregnancy dataset"""
    query = state["question"]

    # Search FAISS for top 2 similar questions
    query_emb = embedder.encode([query], convert_to_numpy=True)
    D, I = index.search(query_emb, 2)

    retrieved = [dataset[int(i)] for i in I[0]]

    return {
        "retrieved_context": retrieved
    }


def generate_reasoning_and_answer(state: GraphState, menstrual_llama, tokenizer) -> GraphState:
    """
    Single unified node: Generate both reasoning and answer using Menstrual-LLaMA
    """
    query = state["question"]
    context = state["retrieved_context"]
    chat_history = state["messages"][:-1]

    # Format retrieved pregnancy context
    context_str = ""
    if context:
        context_str = "\n\nRetrieved Pregnancy Knowledge Base (PREGNANCY ONLY - use ONLY if relevant):\n"
        for i, ctx in enumerate(context):
            context_str += f"\nReference {i+1}:\nQ: {ctx['question']}\nA: {ctx['answer']}\n"

    # Format chat history for multi-turn conversation
    history_str = format_chat_history(chat_history, max_exchanges=3)

    # Removed language detection and forceful language instruction

    # Construct system message
    system_message = """You are MedLang, an expert AI assistant for women's health, specializing in BOTH menstrual health and pregnancy. You MUST maintain conversational continuity and use the provided chat history to understand the context of follow-up questions.

YOUR KNOWLEDGE:
- You have been fine-tuned on 24,000+ expert-verified menstrual health Q&A pairs.
- You also have general pregnancy knowledge from your base training.
- You are capable of handling questions about periods, menstruation, PMS, PCOS, ovulation, fertility, pregnancy, conception, prenatal care, and more.
- You are multilingual and should respond in the same language as the user's query.
"""

    # Construct user message with all context
    user_message = f"""
{history_str}
CURRENT USER QUESTION: {query}

{context_str}

INSTRUCTIONS FOR ANSWERING:
1. REASONING FIRST (2-3 sentences):
   - **CRITICAL:** Analyze the **Previous conversation** and the **CURRENT USER QUESTION** together. Identify if this is a follow-up question (e.g., "What kind?") and explicitly state what it refers to (e.g., "What kind of songs for the baby").
   - Identify the primary topic: **Menstrual Health**, **Pregnancy/Fertility**, or **Irrelevant/General**.
   - If Menstrual Health, note that you will rely primarily on your internal knowledge.
   - If Pregnancy/Fertility, assess if the Retrieved Pregnancy Knowledge Base is relevant.
   - **CRITICAL:** If the query is about Menstrual Health (like delayed periods), explicitly state that the RAG context (which is pregnancy-only) is **IGNORED** for the answer.
   - Note the query language.

2. ANSWER SECOND (4-7 sentences):
   - **CRITICAL:** Do NOT give vague answers. Provide **SPECIFIC examples, causes, or types**.
   - **For information requiring specific detail (like causes of delayed periods or types of music): use a numbered or bulleted list in the answer.**
   - Use your extensive menstrual health knowledge as your PRIMARY source.
   - Use the Retrieved Pregnancy Knowledge Base **ONLY** if the query is clearly about a pregnancy topic.
   - For severe symptoms or emergencies, always include the standard medical disclaimer.
   - CRITICAL: Respond in the SAME language as the user's question.

FORMAT YOUR RESPONSE EXACTLY AS:
**REASONING:**
[Your 2-3 sentence pragmatic inference here]

**ANSWER:**
[Your detailed, specific, and structured response here, including lists where appropriate]"""

    try:
        # Use the chat template as specified in HuggingFace model card
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]

        # Apply chat template (LLaMA-3 format)
        input_ids = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            return_tensors="pt"
        ).to(menstrual_llama.device)

        # Terminators as specified in model card
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        # Disable specific CUDA optimizations as per model card
        if torch.cuda.is_available():
            torch.backends.cuda.enable_mem_efficient_sdp(False)
            torch.backends.cuda.enable_flash_sdp(False)

        # Generate with parameters from model card
        with torch.no_grad():
            outputs = menstrual_llama.generate(
                input_ids,
                pad_token_id=tokenizer.pad_token_id,
                max_new_tokens=400,
                eos_token_id=terminators,
                do_sample=True,
                temperature=0.6,  # As per model card
                top_p=0.9
            )

        # Extract only the generated response
        response = outputs[0][input_ids.shape[-1]:]
        response_text = tokenizer.decode(response, skip_special_tokens=True)

        # Parse reasoning and answer
        reasoning = ""
        answer = ""

        if "**REASONING:**" in response_text and "**ANSWER:**" in response_text:
            parts = response_text.split("**ANSWER:**")
            reasoning = parts[0].replace("**REASONING:**", "").strip()
            answer = parts[1].strip()
        elif "REASONING:" in response_text and "ANSWER:" in response_text:
            parts = response_text.split("ANSWER:")
            reasoning = parts[0].replace("REASONING:", "").strip()
            answer = parts[1].strip()
        else:
            # Fallback: treat entire response as answer
            answer = response_text
            reasoning = "Analyzing query and generating response based on training knowledge."

        return {
            "reasoning": reasoning,
            "answer": answer,
            "messages": [AIMessage(content=answer)]
        }

    except Exception as e:
        error_msg = f"I apologize, but I encountered an error processing your question. Error: {str(e)}"
        return {
            "reasoning": "Error occurred during processing",
            "answer": error_msg,
            "messages": [AIMessage(content=error_msg)]
        }


# --- Build the Graph ---
@st.cache_resource
def create_chatbot_graph(_embedder, _dataset, _index, _menstrual_llama, _tokenizer):
    """Create the LangGraph workflow with Menstrual-LLaMA"""
    workflow = StateGraph(GraphState)

    # Add nodes
    workflow.add_node(
        "retrieve",
        lambda state: retrieve_context(state, _dataset, _index, _embedder)
    )

    workflow.add_node(
        "reason_and_answer",
        lambda state: generate_reasoning_and_answer(state, _menstrual_llama, _tokenizer)
    )

    # Add edges: retrieve ‚Üí reason_and_answer ‚Üí END
    workflow.add_edge("retrieve", "reason_and_answer")
    workflow.add_edge("reason_and_answer", END)

    # Set entry point
    workflow.set_entry_point("retrieve")

    # Compile with memory for multi-turn conversations
    memory = MemorySaver()
    app = workflow.compile(checkpointer=memory)

    return app


# --- Initialize Session State ---
def initialize_session_state():
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thread_id" not in st.session_state:
        import uuid
        st.session_state.thread_id = str(uuid.uuid4())
    if "show_reasoning" not in st.session_state:
        st.session_state.show_reasoning = True
    if "show_context" not in st.session_state:
        st.session_state.show_context = False
    if "query_count" not in st.session_state:
        st.session_state.query_count = 0


# --- Main App ---
def main():
    # Initialize
    initialize_session_state()
    hf_token = load_environment()
    embedder = initialize_embedder()

    # Load Menstrual-LLaMA with access token
    menstrual_llama, tokenizer = load_menstrual_llama(hf_token)

    # Sidebar
    with st.sidebar:
        st.markdown("### ü§∞ MedLang")
        st.markdown("*Women's Health Companion*")
        st.markdown("---")

        # Dataset file input
        data_file = st.text_input(
            "Pregnancy Dataset Path",
            value="merged_preg_dataset.jsonl",
            help="Path to your pregnancy Q&A JSONL file (1,378 pairs)"
        )

        # Load dataset and create graph
        if st.button("üîÑ Reload Dataset & Model"):
            st.cache_resource.clear()
            st.rerun()

        dataset, index = load_dataset_and_index(data_file, embedder)
        app = create_chatbot_graph(embedder, dataset, index, menstrual_llama, tokenizer)

        # st.success(f"‚úÖ {len(dataset)} pregnancy Q&A pairs loaded")
        # st.success("‚úÖ Menstrual-LLaMA-8B active (Quantized Load)")
        # st.info("‚ÑπÔ∏è Model trained on 24k+ menstrual Q&As")

        st.markdown("---")

        # Settings
        st.markdown("### ‚öôÔ∏è Display Settings")
        st.session_state.show_reasoning = st.checkbox(
            "Show Reasoning Process",
            value=st.session_state.show_reasoning,
            help="Display the model's internal reasoning"
        )
        st.session_state.show_context = st.checkbox(
            "Show Retrieved Context",
            value=st.session_state.show_context,
            help="Display pregnancy Q&As retrieved from RAG"
        )

        st.markdown("---")

        # # Statistics
        # st.markdown("### üìä Session Statistics")
        # st.markdown(f"""
        # <div class="stats-card">
        #     <strong>Total Queries:</strong> {st.session_state.query_count}<br>
        # </div>
        # """, unsafe_allow_html=True)

        # st.markdown("---")

        # Clear conversation
        if st.button("üóëÔ∏è Clear Conversation"):
            st.session_state.messages = []
            st.session_state.query_count = 0
            # Removed language_stats reset
            import uuid
            st.session_state.thread_id = str(uuid.uuid4())
            st.rerun()

        st.markdown("---")
        st.markdown("### ‚ÑπÔ∏è About MedLang")
        st.markdown("""
        **Model Architecture:**
        - üî¥ **Menstrual-LLaMA**: Fine-tuned on 24,000+ expert-verified menstrual health Q&A pairs
        - üìö **RAG Enhancement**: Retrieves relevant pregnancy Q&As when needed
        - üß† **Autonomous Decision Making**: Model intelligently decides when to use retrieved context

        **Capabilities:**
        - ‚úÖ Menstrual health (periods, PMS, PCOS, ovulation)
        - ‚úÖ Pregnancy (conception, prenatal care, symptoms)
        - ‚úÖ Fertility & reproductive health
        - ‚úÖ Multi-turn conversations with memory
        - ‚úÖ Multilingual queries supported

        **Features:**
        - Context-aware responses
        - Conversational memory via LangGraph
        - Reasoning transparency
        - Privacy-focused
        """)

        st.markdown("---")
        st.markdown("### üí° Example Questions")
        st.markdown("""
        - "What causes irregular periods?"
        - "Is cramping normal in early pregnancy?"
        - "How can I track my ovulation?"
        - "‡§Æ‡§æ‡§∏‡§ø‡§ï ‡§ß‡§∞‡•ç‡§Æ ‡§Æ‡•á‡§Ç ‡§¶‡•á‡§∞‡•Ä ‡§ï‡•á ‡§ï‡•ç‡§Ø‡§æ ‡§ï‡§æ‡§∞‡§£ ‡§π‡•à‡§Ç?"
        - "‡§ó‡§∞‡•ç‡§≠‡§æ‡§µ‡§∏‡•ç‡§•‡§æ ‡§ï‡•á ‡§∂‡•Å‡§∞‡•Å‡§Ü‡§§‡•Ä ‡§≤‡§ï‡•ç‡§∑‡§£ ‡§ï‡•ç‡§Ø‡§æ ‡§π‡•à‡§Ç?"
        """)

        st.markdown("---")
        st.markdown("‚ö†Ô∏è *This is not a substitute for professional medical advice. Always consult a healthcare provider for serious concerns.*")

    # Main chat interface
    st.markdown("<h1 class='main-header'>ü§∞ MedLang - Women's Health Companion</h1>",
                unsafe_allow_html=True)

    st.markdown("""
    <div style='text-align: center; color: #666; margin-bottom: 2rem;'>
    Ask questions about <b>pregnancy</b> or <b>menstrual health</b> üåè<br>
    <em>Powered by Menstrual-LLaMA-8B with RAG enhancement</em>
    </div>
    """, unsafe_allow_html=True)

    # Display chat messages
    for message in st.session_state.messages:
        if message["role"] == "user":
            with st.container():
                st.markdown(f"""
                    <div class="chat-message user-message">
                        <strong>üë§ You:</strong><br>
                        {message["content"]}
                    </div>
                """, unsafe_allow_html=True)
        else:
            with st.container():
                # Model badge
                st.markdown("""
                    <span class="model-badge">
                        üî¥ Menstrual-LLaMA-8B
                    </span>
                """, unsafe_allow_html=True)

                # Retrieved context (if enabled)
                if st.session_state.show_context and "context" in message and message["context"]:
                    st.markdown("<strong>üìö Retrieved Pregnancy Context (RAG):</strong>", unsafe_allow_html=True)
                    for i, ctx in enumerate(message["context"]):
                        with st.expander(f"Reference {i+1}: {ctx['question'][:80]}...", expanded=False):
                            st.markdown(f"**Q:** {ctx['question']}")
                            st.markdown(f"**A:** {ctx['answer'][:300]}...")

                # Reasoning (if enabled)
                if st.session_state.show_reasoning and "reasoning" in message and message["reasoning"]:
                    st.markdown(f"""
                        <div class="reasoning-box">
                            <strong>üß† Model Reasoning:</strong><br>
                            {message["reasoning"]}
                        </div>
                    """, unsafe_allow_html=True)

                # Answer
                st.markdown(f"""
                    <div class="chat-message assistant-message">
                        <strong>ü§ñ MedLang:</strong><br>
                        {message["content"]}
                    </div>
                """, unsafe_allow_html=True)

    # Chat input
    user_input = st.chat_input("Ask about pregnancy or menstrual health...")

    if user_input:
        # Update query count
        st.session_state.query_count += 1

        # Add user message
        st.session_state.messages.append({
            "role": "user",
            "content": user_input
        })

        # Display user message immediately
        with st.container():
            st.markdown(f"""
                <div class="chat-message user-message">
                    <strong>üë§ You:</strong><br>
                    {user_input}
                </div>
            """, unsafe_allow_html=True)

        # Generate response
        with st.spinner("ü§î Thinking..."):
            try:
                # Prepare initial state with conversation history
                history_messages = [msg for msg in st.session_state.messages[:-1]
                                   if isinstance(msg.get("content"), str)]

                langchain_history = []
                for msg in history_messages:
                    if msg["role"] == "user":
                        langchain_history.append(HumanMessage(content=msg["content"]))
                    else:
                        langchain_history.append(AIMessage(content=msg["content"]))

                initial_state = {
                    "messages": langchain_history + [HumanMessage(content=user_input)],
                    "question": user_input,
                    "retrieved_context": [],
                    "reasoning": "",
                    "answer": ""
                }

                config = {"configurable": {"thread_id": st.session_state.thread_id}}
                result = app.invoke(initial_state, config)

                # Add assistant message
                st.session_state.messages.append({
                    "role": "assistant",
                    "content": result["answer"],
                    "reasoning": result["reasoning"],
                    "context": result["retrieved_context"]
                })

                st.rerun()

            except Exception as e:
                st.error(f"‚ö†Ô∏è Error: {str(e)}")
                st.error("Please try rephrasing your question or check the model setup.")


if __name__ == "__main__":
    main()

Overwriting app.py


In [None]:
# # Run the Streamlit app in the background
# !nohup streamlit run medlang-app.py &

# # Wait a moment for Streamlit to initialize
# import time
# time.sleep(5)

# # Use localtunnel to expose port 8501
# print("Streamlit App is running. Click the link below:")
# !lt --port 8501 --subdomain medlang-app & curl https://loca.lt/mytunnelpassword

!streamlit run app.py &>/content/logs.txt & npx localtunnel --port 8501 & curl https://loca.lt/mytunnelpassword

35.227.165.190[1G[0K‚†ô[1G[0K‚†π[1G[0K‚†∏[1G[0K‚†º[1G[0K‚†¥[1G[0K‚†¶[1G[0K‚†ß[1G[0K‚†á[1G[0K‚†è[1G[0K‚†ã[1G[0K‚†ô[1G[0K‚†π[1G[0K‚†∏[1G[0K‚†º[1G[0K‚†¥[1G[0K‚†¶[1G[0K‚†ß[1G[0K‚†á[1G[0K‚†è[1G[0K‚†ã[1G[0K‚†ô[1G[0K‚†π[1G[0K‚†∏[1G[0K‚†º[1G[0Kyour url is: https://warm-bugs-walk.loca.lt


In [None]:
import json
import torch
import os
import numpy as np
import operator
import time
from typing import TypedDict, Annotated, Sequence
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv

# --- CONFIGURATION ---
TEST_SET_FILE = "test_set.jsonl" # This is the file you create
DATASET_FILE = "merged_preg_dataset.jsonl" # Your original RAG data
MODEL_PATH = "proadhikary/Menstrual-LLaMA-8B"

# --- Helper Functions from app.py (Modified slightly for standalone script) ---

# Note: Simplified version of load_menstrual_llama without Streamlit caching/spinner
def load_menstrual_llama_eval():
    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        print("ERROR: HF_TOKEN not found in .env file.")
        return None, None

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    print("Loading Menstrual-LLaMA-8B with 4-bit quantization...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        token=hf_token,
        quantization_config=bnb_config,
        device_map="auto",
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.eval()
    print("Model loaded successfully.")
    return model, tokenizer

# Helper function to compute semantic similarity
def calculate_semantic_similarity(generated_answer, ground_truth, embedder):
    """Calculates cosine similarity between the embeddings of the two texts."""
    if not generated_answer or not ground_truth:
        return 0.0

    embeddings = embedder.encode(
        [generated_answer, ground_truth],
        convert_to_numpy=True
    )
    # Cosine similarity between the two vectors
    return cosine_similarity(
        embeddings[0].reshape(1, -1),
        embeddings[1].reshape(1, -1)
    )[0][0]

# --- LangGraph Nodes (simplified, using logic from app.py) ---

class GraphState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    question: str
    retrieved_context: list
    reasoning: str
    answer: str

def retrieve_context(state, dataset, index, embedder):
    query = state["question"]
    query_emb = embedder.encode([query], convert_to_numpy=True)
    D, I = index.search(query_emb, 2)
    retrieved = [dataset[int(i)] for i in I[0]]
    return {"retrieved_context": retrieved}

def generate_reasoning_and_answer(state, menstrual_llama, tokenizer):
    """
    Evaluation version of the generation node with simplified single-turn prompt
    that forces specificity and RAG decision making.
    """
    query = state["question"]
    context = state["retrieved_context"]

    # 1. Format retrieved pregnancy context
    context_str = ""
    if context:
        context_str = "\n\nRetrieved Pregnancy Knowledge Base (PREGNANCY ONLY - use ONLY if relevant):\n"
        for i, ctx in enumerate(context):
            context_str += f"\nReference {i+1}:\nQ: {ctx['question']}\nA: {ctx['answer']}\n"

    # 2. Construct System Message (Simplified persona)
    system_message = """You are MedLang, an expert AI assistant for women's health, specializing in BOTH menstrual health and pregnancy. Your goal is to provide concise, structured, and factual medical information.

YOUR KNOWLEDGE:
- You have been fine-tuned on 24,000+ expert-verified menstrual health Q&A pairs.
- You have general pregnancy knowledge.
- You are capable of handling questions about periods, PCOS, pregnancy, and more.
"""

    # 3. Construct User Message (Focus on RAG/Specificity Instructions)
    user_message = f"""CURRENT USER QUESTION: {query}

{context_str}

INSTRUCTIONS FOR ANSWERING:
1. REASONING FIRST (2-3 sentences):
   - Identify the primary topic: **Menstrual Health** or **Pregnancy/Fertility**.
   - If Menstrual Health, explicitly state that the RAG context (which is pregnancy-only) is **IGNORED** for the answer.
   - If Pregnancy/Fertility, assess if the Retrieved Pregnancy Knowledge Base is relevant and state whether it will be used.

2. ANSWER SECOND (4-7 sentences):
   - **CRITICAL:** Do NOT give vague answers like "various reasons" or "various songs." Provide **SPECIFIC examples, causes, or types**.
   - **For information requiring specific detail, use a numbered or bulleted list in the answer.**
   - Use your extensive menstrual health knowledge as your PRIMARY source.
   - Use the Retrieved Pregnancy Knowledge Base **ONLY** if the query is clearly about a pregnancy topic.
   - For severe symptoms or emergencies, always include the standard medical disclaimer: "Please consult a healthcare provider immediately for personalized medical advice."

FORMAT YOUR RESPONSE EXACTLY AS:
**REASONING:**
[Your 2-3 sentence pragmatic inference here]

**ANSWER:**
[Your detailed, specific, and structured response here, including lists where appropriate]"""

    # 4. Invoke the model
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(menstrual_llama.device)

    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    with torch.no_grad():
        outputs = menstrual_llama.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=400,
            eos_token_id=terminators,
            do_sample=False, # Use deterministic generation for better eval consistency
            temperature=0.6,
            top_p=0.9
        )

    response = outputs[0][input_ids.shape[-1]:]
    response_text = tokenizer.decode(response, skip_special_tokens=True).strip()

    # Parse reasoning and answer from structured output
    reasoning = "N/A (Parsing Error)"
    answer = response_text

    if "**REASONING:**" in response_text and "**ANSWER:**" in response_text:
        parts = response_text.split("**ANSWER:**")
        reasoning = parts[0].replace("**REASONING:**", "").strip()
        answer = parts[1].strip()

    return {
        "reasoning": reasoning,
        "answer": answer,
        "messages": [AIMessage(content=answer)]
    }

def create_chatbot_graph(embedder, dataset, index, menstrual_llama, tokenizer):
    workflow = StateGraph(GraphState)
    workflow.add_node("retrieve", lambda state: retrieve_context(state, dataset, index, embedder))
    workflow.add_node("reason_and_answer", lambda state: generate_reasoning_and_answer(state, menstrual_llama, tokenizer))
    workflow.add_edge("retrieve", "reason_and_answer")
    workflow.add_edge("reason_and_answer", END)
    workflow.set_entry_point("retrieve")
    return workflow.compile()

# --- MAIN EVALUATION FUNCTION ---

def run_evaluation():
    # 1. Load Data
    print("--- 1. Loading Models and Data ---")
    embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    menstrual_llama, tokenizer = load_menstrual_llama_eval()

    if menstrual_llama is None:
        return

    # Load original dataset for RAG Indexing
    try:
        qa_pairs_rag = []
        with open(DATASET_FILE, "r", encoding="utf-8") as f:
            for line in f:
                qa_pairs_rag.append(json.loads(line))

        # This dataset is only used for RAG retrieval during evaluation
        rag_dataset = Dataset.from_list(qa_pairs_rag)
        question_embeddings = embedder.encode(rag_dataset["question"], convert_to_numpy=True)

        import faiss
        dim = question_embeddings.shape[1]
        rag_index = faiss.IndexFlatL2(dim)
        rag_index.add(question_embeddings)
        print(f"RAG Index loaded with {len(rag_dataset)} documents.")

    except FileNotFoundError:
        print(f"ERROR: RAG dataset file '{DATASET_FILE}' not found. Cannot proceed.")
        return
    except Exception as e:
        print(f"ERROR loading RAG index: {e}")
        return

    # Load Golden Test Set
    try:
        with open(TEST_SET_FILE, "r", encoding="utf-8") as f:
            test_cases = [json.loads(line) for line in f]
        print(f"Golden Test Set loaded with {len(test_cases)} queries.")
    except FileNotFoundError:
        print(f"ERROR: Test set file '{TEST_SET_FILE}' not found. Please create it first.")
        return

    # Initialize LangGraph
    app = create_chatbot_graph(embedder, rag_dataset, rag_index, menstrual_llama, tokenizer)

    print("\n--- 2. Running Evaluation ---")

    # Metrics tracking
    total_queries = len(test_cases)
    total_similarity_score = 0
    retrieval_success_count = 0
    start_time = time.time()

    for i, case in enumerate(test_cases):
        print(f"Query {i+1}/{total_queries}: {case['question'][:50]}...")

        # 2a. Run the LangGraph system
        initial_state = {
            "messages": [HumanMessage(content=case["question"])],
            "question": case["question"],
            "retrieved_context": [],
            "reasoning": "",
            "answer": ""
        }

        # Note: We use a static thread_id for this single-pass evaluation
        result = app.invoke(initial_state, {"configurable": {"thread_id": "eval_thread"}})

        # 2b. Calculate Semantic Similarity
        generated_answer = result['answer']
        similarity = calculate_semantic_similarity(
            generated_answer,
            case['ground_truth_answer'],
            embedder
        )
        total_similarity_score += similarity

        # 2c. Calculate Retrieval Accuracy@2 (Only for Pregnancy queries)
        if case['is_pregnancy'] and case.get('ground_truth_question'):
            # Check if the expected question text is present in the top 2 retrieved Q's
            retrieved_questions = [r['question'] for r in result['retrieved_context']]
            if case['ground_truth_question'].strip() in [q.strip() for q in retrieved_questions]:
                retrieval_success_count += 1

    end_time = time.time()

    # 3. Calculate Final Scores
    avg_similarity = total_similarity_score / total_queries
    # Retrieval accuracy only counts the pregnancy (RAG-reliant) queries
    num_rag_queries = sum(1 for c in test_cases if c['is_pregnancy'])

    if num_rag_queries > 0:
        retrieval_accuracy = retrieval_success_count / num_rag_queries
    else:
        retrieval_accuracy = 0.0

    print("\n" + "="*50)
    print("         MEDLANG EVALUATION SUMMARY")
    print("="*50)
    print(f"Total Queries Tested: {total_queries}")
    print(f"Total Runtime: {end_time - start_time:.2f} seconds")
    print(f"Average Inference Time: {(end_time - start_time) / total_queries:.2f} seconds/query")
    print("\n--- METRICS ---")
    print(f"1. Semantic Similarity Score (Avg.): {avg_similarity:.4f}")
    print(f"2. Retrieval Accuracy@2 (RAG Queries): {retrieval_accuracy:.4f} ({retrieval_success_count}/{num_rag_queries} correct)")
    print("="*50)

if __name__ == "__main__":
    run_evaluation()

--- 1. Loading Models and Data ---


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Loading Menstrual-LLaMA-8B with 4-bit quantization...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded successfully.


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


RAG Index loaded with 942 documents.
Golden Test Set loaded with 120 queries.

--- 2. Running Evaluation ---
Query 1/120: What is the scientific term for "chumps" or "perio...
Query 2/120: What is ‚Äúmenarche‚Äù?...
Query 3/120: At what age do menses usually begin?...
Query 4/120: What happens during menses?...
Query 5/120: For how long does the bleeding last?...
Query 6/120: Is it very painful?...
Query 7/120: What is a sanitary pad?...
Query 8/120: Is it okay to use a cloth instead of a sanitary pa...
Query 9/120: Are sanitary pads too costly?...
Query 10/120: Can sanitary pads be availed free of cost?...
Query 11/120: What needs to be done after using a sanitary pad?...
Query 12/120: What is the proper way to dispose of sanitary pads...
Query 13/120: Can sanitary pads be disposed of in a commode?...
Query 14/120: How many sanitary pads are required per month?...
Query 15/120: How often should sanitary pads be changed during t...
Query 16/120: What will happen if I use the same sanit

llama 3 8b evaluation on 120 Q/A:

In [3]:
import json
import torch
import os
import time
from typing import Sequence
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv

# --- CONFIGURATION ---
TEST_SET_FILE = "test_set.jsonl"
# Using a widely available 4-bit quantized LLaMA-3-8B-Instruct model for baseline comparison
BASELINE_MODEL_PATH = "unsloth/llama-3-8b-Instruct-bnb-4bit"

# --- Helper Functions ---

def load_model_eval(model_path, hf_token):
    """Loads the LLaMA-3-8B-Instruct model with 4-bit quantization."""
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    print(f"Loading Baseline Model: {model_path} with 4-bit quantization...")
    try:
        # Note: AutoTokenizer from a base LLaMA-3 path often works well
        tokenizer = AutoTokenizer.from_pretrained(model_path, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            token=hf_token,
            quantization_config=bnb_config,
            device_map="auto",
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model.eval()
        print("Baseline Model loaded successfully.")
        return model, tokenizer
    except Exception as e:
        print(f"ERROR loading model {model_path}: {e}")
        return None, None

def calculate_semantic_similarity(generated_answer, ground_truth, embedder):
    """Calculates cosine similarity between the embeddings of the two texts."""
    if not generated_answer or not ground_truth:
        return 0.0

    embeddings = embedder.encode(
        [generated_answer, ground_truth],
        convert_to_numpy=True
    )
    return cosine_similarity(
        embeddings[0].reshape(1, -1),
        embeddings[1].reshape(1, -1)
    )[0][0]

def baseline_generate(query, model, tokenizer):
    """
    Generates a response using the Base LLaMA-3-Instruct model in a zero-shot manner.
    NO RAG, simple instruction prompt focusing on detailed health answers.
    """
    system_message = "You are a helpful and expert AI assistant for women's health, providing factual, sensitive, and detailed medical information."

    # Prompt instructs for specificity and structured output, mirroring MedLang's requirements
    # to make the output quality comparable.
    user_message = f"""User Question: {query}

    Provide a detailed answer of 5-7 sentences. For lists of causes, symptoms, or examples, please use a numbered or bulleted list. Maintain a sensitive and factual tone."""

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]

    # Model Inference (Deterministic generation for better evaluation consistency)
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=400,
            eos_token_id=terminators,
            do_sample=False, # Use deterministic generation
            temperature=0.6,
            top_p=0.9
        )

    response_text = tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True).strip()

    return response_text

# --- MAIN EVALUATION FUNCTION ---

def run_baseline_evaluation():
    load_dotenv()
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        print("ERROR: HF_TOKEN not found in .env file.")
        return

    # 1. Load Data & Embedder
    print("--- 1. Loading Models and Data ---")
    embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

    # Load Golden Test Set
    try:
        with open(TEST_SET_FILE, "r", encoding="utf-8") as f:
            test_cases = [json.loads(line) for line in f]
        print(f"Golden Test Set loaded with {len(test_cases)} queries.")
    except FileNotFoundError:
        print(f"ERROR: Test set file '{TEST_SET_FILE}' not found. Please create it first.")
        return

    # Load Baseline Model
    baseline_model, baseline_tokenizer = load_model_eval(BASELINE_MODEL_PATH, hf_token)

    if not baseline_model:
        return

    print("\n--- 2. Running Baseline Evaluation (LLaMA-3 Zero-Shot) ---")

    # Metrics tracking
    total_queries = len(test_cases)
    baseline_similarity_score = 0

    start_time = time.time()

    for i, case in enumerate(test_cases):
        print(f"Query {i+1}/{total_queries}: {case['question'][:50]}...")

        query = case["question"]
        gt_answer = case['ground_truth_answer']

        # 2a. LLaMA-3 Baseline Generation (Zero-Shot)
        generated_answer = baseline_generate(query, baseline_model, baseline_tokenizer)

        # 2b. Calculate Semantic Similarity
        similarity = calculate_semantic_similarity(generated_answer, gt_answer, embedder)
        baseline_similarity_score += similarity

    end_time = time.time()

    # 3. Calculate Final Scores
    avg_baseline_similarity = baseline_similarity_score / total_queries

    print("\n" + "="*50)
    print("      LLaMA-3 BASELINE EVALUATION SUMMARY")
    print("="*50)
    print(f"Total Queries Tested: {total_queries}")
    print(f"Total Runtime: {end_time - start_time:.2f} seconds")
    print(f"Average Inference Time: {(end_time - start_time) / total_queries:.2f} seconds/query")
    print("\n--- METRICS ---")
    print(f"1. LLaMA-3 Baseline (Zero-Shot) Avg. Semantic Similarity: {avg_baseline_similarity:.4f}")
    print("="*50)

if __name__ == "__main__":
    run_baseline_evaluation()

--- 1. Loading Models and Data ---
Golden Test Set loaded with 120 queries.
Loading Baseline Model: unsloth/llama-3-8b-Instruct-bnb-4bit with 4-bit quantization...


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/345 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]



model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/220 [00:00<?, ?B/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Baseline Model loaded successfully.

--- 2. Running Baseline Evaluation (LLaMA-3 Zero-Shot) ---
Query 1/120: What is the scientific term for "chumps" or "perio...
Query 2/120: What is ‚Äúmenarche‚Äù?...
Query 3/120: At what age do menses usually begin?...
Query 4/120: What happens during menses?...
Query 5/120: For how long does the bleeding last?...
Query 6/120: Is it very painful?...
Query 7/120: What is a sanitary pad?...
Query 8/120: Is it okay to use a cloth instead of a sanitary pa...
Query 9/120: Are sanitary pads too costly?...
Query 10/120: Can sanitary pads be availed free of cost?...
Query 11/120: What needs to be done after using a sanitary pad?...
Query 12/120: What is the proper way to dispose of sanitary pads...
Query 13/120: Can sanitary pads be disposed of in a commode?...
Query 14/120: How many sanitary pads are required per month?...
Query 15/120: How often should sanitary pads be changed during t...
Query 16/120: What will happen if I use the same sanitary pad fo...