In [None]:
# chainlit run app.py -w
"""
PrescriptionIQ: A pharmacist assistant chatbot.
This app uses Chainlit, LangChain, and Qdrant to answer pharmacy-related queries.
It fetches related PubMed papers, suggests follow-up questions,
and displays source details in a clickable, expandable format.
"""

# -------------------------
# Standard library imports
# -------------------------
import asyncio
import io
import json
import os
import re
import requests
import zipfile

# Data handling
import pandas as pd

# Environment variables
from dotenv import load_dotenv

# Typing for function signatures
from typing import Any, List, Optional

# Bioinformatics
from Bio import Entrez, Medline

# -------------------------
# Chainlit & LangChain imports
# -------------------------
import chainlit as cl
from chainlit.types import AskFileResponse

from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain_community.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.evaluation import StringEvaluator
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
from langchain_openai import OpenAI, OpenAIEmbeddings

# -------------------------
# Vector storage & Document Loading
# -------------------------
from langchain_community.document_loaders import DataFrameLoader
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient, AsyncQdrantClient

# -------------------------
# Custom Evaluations
# -------------------------
from custom_eval import PharmAssistEvaluator, HarmfulnessEvaluator, AIDetectionEvaluator

# -------------------------
# Custom Evaluations
# -------------------------
from custom_eval import PharmAssistEvaluator, HarmfulnessEvaluator, AIDetectionEvaluator

# -------------------------
# Load Environment Variables
# -------------------------
# Load .env BEFORE constructing any client below: the LangSmith client resolves
# its API key / endpoint from environment variables when it is instantiated, so
# values that only live in the .env file would otherwise not be picked up.
load_dotenv()

# -------------------------
# LangSmith Client
# -------------------------
from langsmith import Client
langsmith_client = Client()

# -------------------------
# System Prompt & Message Setup
# -------------------------
# System instructions for the assistant. {summaries} is the template variable
# where the supporting context is expected to be inserted.
system_template = """
You are an AI assistant for pharmacists and pharmacy students.
Use the context provided to answer the user's question.

If you don't have enough information, say so—do not fabricate an answer.

Always include a **SOURCES** section at the end of your response, referencing the documents used.

Example response format:
**Answer:**
<your answer here>

**SOURCES:**
Source 1, Source 2, etc.

Begin!
----------------
{summaries}
"""

# Chat prompt layout: the system instructions first, then the user's question verbatim.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]
prompt = ChatPromptTemplate.from_messages(messages)
# NOTE(review): chain_type_kwargs is not referenced by the code visible in this
# file; confirm whether it is still needed.
chain_type_kwargs = dict(prompt=prompt)

# Module-wide handle to the Qdrant vector store; assigned in on_chat_start.
qdrant_vectorstore = None

# -------------------------
# Function: Search Related Papers on PubMed
# -------------------------
def _fetch_pubmed_citations(query, max_results):
    """
    Blocking helper: search PubMed and return Markdown citation links.

    Returns an empty list when the search matches nothing or when none of the
    fetched records carries a PMID. Errors raised by Entrez propagate to the caller.
    """
    # Stays None when ENTREZ_EMAIL is not configured.
    Entrez.email = os.environ.get("ENTREZ_EMAIL")

    search_handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results)
    try:
        search_result = Entrez.read(search_handle)
    finally:
        search_handle.close()

    id_list = search_result["IdList"]
    if not id_list:
        return []

    citations = []
    fetch_handle = Entrez.efetch(db="pubmed", id=id_list, rettype="medline", retmode="text")
    try:
        for entry in Medline.parse(fetch_handle):
            pmid = entry.get("PMID")
            if not pmid:
                # No PMID means no URL to link to; skip this record instead of
                # failing the whole lookup.
                continue
            title = entry.get("TI", "")
            authors = ", ".join(entry.get("AU", []))
            citation = f"{authors}. {title}. {entry.get('SO', '')}"
            url = f"https://pubmed.ncbi.nlm.nih.gov/{pmid}/"
            citations.append(f"[{citation}]({url})")
    finally:
        # Close the fetch handle even if parsing fails part-way through.
        fetch_handle.close()
    return citations


async def search_related_papers(query, max_results=3):
    """
    Search PubMed for papers related to the provided query.

    Args:
        query: Free-text PubMed search term.
        max_results: Maximum number of PubMed IDs to request.

    Returns:
        A list of strings. On success each item is a Markdown link of the form
        "[authors. title. source](pubmed url)". When nothing is found, or when
        the lookup fails, the list holds a single human-readable message
        instead. This function does not raise for lookup errors.
    """
    try:
        # The Entrez calls are synchronous network requests; run them in a
        # worker thread so the event loop stays responsive.
        related_papers = await asyncio.to_thread(_fetch_pubmed_citations, query, max_results)
    except Exception as e:
        print(f"Error occurred while searching for related papers: {e}")
        return ["An error occurred while searching for related papers. Please try again later."]

    if not related_papers:
        return ["No directly related papers found. Try broadening your search query."]
    return related_papers

# -------------------------
# Function: Generate Related Questions
# -------------------------
async def generate_related_questions(retrieved_results, num_questions=2, max_tokens=50):
    """
    Generate follow-up questions from the retrieved document context.

    Args:
        retrieved_results: Iterable of documents exposing ``page_content``;
            ``None`` or an empty iterable is accepted.
        num_questions: Number of questions requested from the model; also the
            maximum number returned.
        max_tokens: Completion token limit applied to the LLM call.

    Returns:
        A list of question strings with any leading list markers
        ("1.", "2)", "-", "*") removed. Empty when there is no context.
    """
    context = " ".join(doc.page_content for doc in (retrieved_results or []))
    if not context.strip():
        # Nothing to base questions on; skip the LLM call entirely.
        return []

    # max_tokens is a generation setting, so it belongs on the LLM rather than
    # among the prompt inputs.
    llm = OpenAI(temperature=0.7, max_tokens=max_tokens)
    prompt_template = PromptTemplate(
        # Both placeholders used by the template are declared.
        input_variables=["context", "num_questions"],
        template="Given the following context, generate {num_questions} related questions:\n\nContext: {context}\n\nQuestions:",
    )
    chain = LLMChain(llm=llm, prompt=prompt_template)

    # Use the async entry point so the event loop is not blocked while waiting.
    generated_questions = await chain.arun(context=context, num_questions=num_questions)

    related_questions = []
    for line in generated_questions.splitlines():
        # Strip only a leading list marker. Splitting on the first ". " would
        # also cut un-numbered questions that contain a period.
        question = re.sub(r"^\s*(?:\d+[.)]|[-*•])\s*", "", line).strip()
        if question:
            related_questions.append(question)
    return related_questions[:num_questions]

# -------------------------
# Function: Generate Answer
# -------------------------
async def generate_answer(query):
    """
    Generate an answer using a conversational retrieval chain.

    Args:
        query: The user's question.

    Returns:
        A 6-tuple
        (formatted_answer, text_elements, related_question_actions,
         related_papers, source_actions, original_query).
        ``text_elements`` is always an empty list. When the model says it does
        not know, or when an error occurs, all list members are empty and the
        first member carries the model's reply or a generic error message.
        This function does not raise for chain errors.
    """
    try:
        # Chain setup lives inside the try so that configuration problems
        # (e.g. the vector store not being initialised yet) produce the
        # generic error reply instead of an unhandled exception.
        # NOTE: a fresh memory is created on every call, so no chat history is
        # carried over between questions.
        message_history = ChatMessageHistory()
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            chat_memory=message_history,
            return_messages=True,
        )

        chain = ConversationalRetrievalChain.from_llm(
            ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True),
            chain_type="stuff",
            retriever=qdrant_vectorstore.as_retriever(),
            memory=memory,
            return_source_documents=True,
        )

        cb = cl.AsyncLangchainCallbackHandler()
        feedback_callback = EvaluatorCallbackHandler(
            evaluators=[
                PharmAssistEvaluator(),
                HarmfulnessEvaluator(),
                AIDetectionEvaluator(),
            ]
        )

        # 1) Call the chain
        res = await chain.acall(query, callbacks=[cb, feedback_callback])
        answer = res["answer"]
        source_documents = res["source_documents"]

        # 2) Remove any triple backticks to avoid code blocks & "copy" button
        answer = answer.replace("```", "")

        # 3) If the LLM says it does not know, return its reply without extras
        if answer.lower().startswith(("i don't know", "i don't have enough information")):
            return answer, [], [], [], [], query

        # 4) Format the answer as Markdown
        formatted_answer = f"**Answer:**\n{answer}"

        # 5) Create one source button per source document
        source_actions = [
            cl.Action(
                name="show_source",
                label=f"Source {i+1}",
                payload={"source_content": doc.page_content},
            )
            for i, doc in enumerate(source_documents or [])
        ]

        # 6) + 7) Related questions and PubMed papers are independent of each
        # other, so fetch them concurrently. They are optional extras: if one
        # fails, log it and still return the answer already generated.
        related_questions, related_papers = await asyncio.gather(
            generate_related_questions(source_documents),
            search_related_papers(query),
            return_exceptions=True,
        )
        if isinstance(related_questions, BaseException):
            print(f"Error occurred while generating related questions: {related_questions}")
            related_questions = []
        if isinstance(related_papers, BaseException):
            print(f"Error occurred while searching for related papers: {related_papers}")
            related_papers = []

        related_question_actions = [
            cl.Action(
                name="related_question",
                label=q.strip(),
                payload={"question": q.strip()},
            )
            for q in related_questions if q.strip()
        ]

        return formatted_answer, [], related_question_actions, related_papers, source_actions, query

    except Exception as e:
        print(f"Error occurred: {e}")
        return (
            "An error occurred while processing your request. Please try again later.",
            [],
            [],
            [],
            [],
            query,
        )

# -------------------------
# Action Callback: Show Source Details
# -------------------------
@cl.action_callback("show_source")
async def on_show_source(action: cl.Action):
    """
    Send the full text of a source document when its source button is clicked.

    The text is read from the action payload under "source_content".
    """
    # Drop triple backticks so the source is not rendered as a code block.
    details = action.payload["source_content"].replace("```", "")
    reply = cl.Message(
        content=f"**Source Details:**\n\n{details}",
        author="PrescriptionIQ",
    )
    await reply.send()

# -------------------------
# Action Callback: Related Question Selection
# -------------------------
async def _respond_to_selected_question(question):
    """
    Echo a button-selected question, answer it, and send the follow-up messages.

    Shared by the "related_question" and "ask_question" action callbacks, which
    handle a selected question in exactly the same way.
    """
    # Show the selected question in the chat as if the user had typed it.
    await cl.Message(content=question, author="User").send()

    # Generate the new answer; text elements and the echoed query are not used here.
    (
        answer,
        _text_elements,
        related_question_actions,
        related_papers,
        source_actions,
        _original_query,
    ) = await generate_answer(question)

    await cl.Message(content=answer, author="PrescriptionIQ").send()

    # If we have more related Qs, show them
    if related_question_actions:
        await cl.Message(
            content="**Related Questions:**",
            actions=related_question_actions,
            author="PrescriptionIQ"
        ).send()

    # Show PubMed references if any
    if related_papers:
        papers_md = "**Related Papers from PubMed:**\n" + "\n".join(f"- {paper}" for paper in related_papers)
        await cl.Message(content=papers_md, author="PrescriptionIQ").send()

    # Show the source expansion buttons
    if source_actions:
        await cl.Message(
            content="**Sources:** Click to expand.",
            actions=source_actions,
            author="PrescriptionIQ"
        ).send()


@cl.action_callback("related_question")
async def on_related_question_selected(action: cl.Action):
    """
    Handle a related question selection.

    The question text is read from the action payload under "question".
    """
    await _respond_to_selected_question(action.payload["question"])

# -------------------------
# Action Callback: Ask Question Selection
# -------------------------
@cl.action_callback("ask_question")
async def on_question_selected(action: cl.Action):
    """
    Respond to a selected question from the suggestions.

    The question text is read from the action payload under "question".
    """
    await _respond_to_selected_question(action.payload["question"])

# -------------------------
# on_chat_start
# -------------------------
def _label_field_to_text(value):
    """Return a label field as plain text; list values are joined with spaces."""
    if isinstance(value, list):
        return " ".join(str(item) for item in value)
    return str(value)


def _load_fda_label_documents():
    """
    Download one openFDA drug-label file and return it as chunked Documents.

    This function is blocking (network download plus pandas processing) and is
    meant to be run in a worker thread. Raises on HTTP errors or timeouts.
    """
    url = "https://download.open.fda.gov/drug/label/drug-label-0001-of-0012.json.zip"
    # Without a timeout a stalled connection would hang forever; raise_for_status
    # turns an HTTP error page into an exception instead of a bad-zip error later.
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()
    with zipfile.ZipFile(io.BytesIO(resp.content)) as zip_file:
        with zip_file.open(zip_file.namelist()[0]) as json_file:
            data = json.load(json_file)

    df = pd.json_normalize(data["results"])

    metadata_fields = [
        "openfda.brand_name",
        "openfda.generic_name",
        "openfda.manufacturer_name",
        "openfda.product_type",
        "openfda.route",
        "openfda.substance_name",
        "openfda.rxcui",
        "openfda.spl_id",
        "openfda.package_ndc",
    ]
    text_fields = [
        "description",
        "indications_and_usage",
        "contraindications",
        "warnings",
        "adverse_reactions",
        "dosage_and_administration",
    ]

    df[text_fields] = df[text_fields].fillna("")
    # Join list-valued fields element by element so the embedded text does not
    # contain Python list syntax such as "['...']".
    df["content"] = df[text_fields].apply(
        lambda row: " ".join(_label_field_to_text(value) for value in row), axis=1
    )

    loader = DataFrameLoader(df, page_content_column="content")
    drug_docs = loader.load()

    # Replace each document's metadata with just the selected openfda fields.
    for doc, row in zip(drug_docs, df.to_dict(orient="records")):
        md = {}
        for f in metadata_fields:
            val = row.get(f)
            if isinstance(val, list):
                val = ", ".join(str(v) for v in val if pd.notna(v))
            elif pd.isna(val):
                val = "Not Available"
            md[f] = val
        doc.metadata = md

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return text_splitter.split_documents(drug_docs)


@cl.on_chat_start
async def on_chat_start():
    """
    Initialize the chatbot environment, load Qdrant data, and present initial suggestions.

    On the first chat the module-level ``qdrant_vectorstore`` is set: if the
    "fda_drugs" collection is missing it is built from an openFDA drug-label
    download, otherwise the existing collection is attached. A welcome message
    with suggested-question buttons is then sent.
    """
    global qdrant_vectorstore

    await cl.Message(content="**Loading PrescriptionIQ bot...**").send()
    await asyncio.sleep(2)

    if qdrant_vectorstore is None:
        embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
        QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
        QDRANT_CLUSTER_URL = os.environ.get("QDRANT_CLUSTER_URL")
        qdrant_client = AsyncQdrantClient(url=QDRANT_CLUSTER_URL, api_key=QDRANT_API_KEY, timeout=60)
        try:
            response = await qdrant_client.get_collections()
        finally:
            # This client is only needed for the existence check; release it.
            await qdrant_client.close()
        collection_names = [collection.name for collection in response.collections]

        if "fda_drugs" not in collection_names:
            print("Collection 'fda_drugs' is not present.")
            # Download and preprocess in a worker thread so the event loop is
            # not blocked while the data is fetched and parsed.
            split_drug_docs = await cl.make_async(_load_fda_label_documents)()

            qdrant_vectorstore = await cl.make_async(Qdrant.from_documents)(
                split_drug_docs,
                embedding_model,
                url=QDRANT_CLUSTER_URL,
                api_key=QDRANT_API_KEY,
                collection_name="fda_drugs",
            )
        else:
            print("Collection 'fda_drugs' is present.")
            qdrant_vectorstore = await cl.make_async(Qdrant.construct_instance)(
                texts=[""],
                embedding=embedding_model,
                url=QDRANT_CLUSTER_URL,
                api_key=QDRANT_API_KEY,
                collection_name="fda_drugs",
            )

    # Suggested questions
    potential_questions = [
        "What should I be careful of when taking Metformin?",
        "What are the contraindications of Aspirin?",
        "Are there low-cost alternatives to branded Aspirin available over-the-counter?",
        "What precautions should I take if I'm pregnant or nursing while on Lipitor?",
        "Should Lipitor be taken at a specific time of day, and does it need to be taken with food?",
        "What is the recommended dose of Aspirin?",
        "Can older people take beta blockers?",
        "How do beta blockers work?",
        "Can beta blockers be used for anxiety?",
        "I am taking Aspirin, is it ok to take Glipizide?",
        "Explain in simple terms how Metformin works?",
    ]
    await cl.Message(
        content="**Welcome to PrescriptionIQ!** Here are some potential questions you can ask:",
        actions=[
            cl.Action(name="ask_question", label=q, payload={"question": q})
            for q in potential_questions
        ],
    ).send()
    cl.user_session.set("potential_questions_shown", True)

# -------------------------
# On Message: Free Text
# -------------------------
@cl.on_message
async def main(message):
    """
    Answer a free-text user message.

    Blank input gets a prompt to enter a valid question. Otherwise the answer
    is sent with the original question appended, followed by separate messages
    for related questions, PubMed papers, and source buttons when present.
    Any exception results in a generic error message being sent.
    """
    query = message.content.strip()
    if not query:
        await cl.Message(content="Please enter a valid question.", author="PrescriptionIQ").send()
        return

    try:
        result = await generate_answer(query)
        (
            answer,
            text_elements,
            related_question_actions,
            related_papers,
            source_actions,
            original_query,
        ) = result

        # Fall back to an apology when no answer text came back.
        answer = answer or "Sorry, I couldn't generate an answer. Please try rephrasing your question."

        # Strip any leftover triple backticks, then append the question as a footer.
        cleaned_answer = answer.replace("```", "")
        final_msg = f"{cleaned_answer}\n\n---\n*Your question: {original_query}*"

        reply = cl.Message(content=final_msg, elements=text_elements, author="PrescriptionIQ")
        await reply.send()

        if related_question_actions:
            followups = cl.Message(
                content="**Related Questions:**",
                actions=related_question_actions,
                author="PrescriptionIQ",
            )
            await followups.send()

        if related_papers:
            bullet_lines = [f"- {p}" for p in related_papers]
            papers_md = "**Related Papers from PubMed:**\n" + "\n".join(bullet_lines)
            await cl.Message(content=papers_md, author="PrescriptionIQ").send()

        if source_actions:
            sources = cl.Message(
                content="**Sources:** Click to expand.",
                actions=source_actions,
                author="PrescriptionIQ",
            )
            await sources.send()

    except Exception as e:
        print(f"Error occurred: {e}")
        error_msg = "An error occurred while processing your request. Please try again later."
        await cl.Message(content=error_msg, author="PrescriptionIQ").send()


Number of items: 20000 in the FDA dataset
Number of Document objects: 19472


ModuleNotFoundError: No module named 'openai.error'