In [None]:
# Cell 1: Imports and Environment Setup

import asyncio
import io
import json
import os
import re
import requests
import zipfile

import pandas as pd
from dotenv import load_dotenv
from typing import Any, List, Optional

from Bio import Entrez, Medline

import chainlit as cl
from chainlit.types import AskFileResponse

from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain_community.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.evaluation import StringEvaluator
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.smith import RunEvalConfig, run_on_dataset
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
from langchain_openai import OpenAI, OpenAIEmbeddings

from langchain_community.document_loaders import DataFrameLoader
from langchain_community.vectorstores import Qdrant
from qdrant_client import QdrantClient, AsyncQdrantClient

from custom_eval import PharmAssistEvaluator, HarmfulnessEvaluator, AIDetectionEvaluator

from langsmith import Client
# Load environment variables from the .env file FIRST: clients such as LangSmith's
# Client() read their API keys from the environment when they are constructed, so
# building the client before load_dotenv() would miss credentials stored in .env.
load_dotenv()

# LangSmith client (not referenced in the visible cells; presumably for tracing/evals).
langsmith_client = Client()


In [None]:
# Cell 2: System Prompt and Global Variables

# System prompt for the pharmacist-facing assistant. Retrieved documents are
# substituted into the single {summaries} slot at the end.
# The dash in "say so—do not" must be a real em dash (U+2014); it was previously
# mis-encoded, and the garbled bytes would have been sent to the model verbatim.
system_template = """
You are an AI assistant for pharmacists and pharmacy students.
Use the context provided to answer the user's question.

If you don't have enough information, say so—do not fabricate an answer.

Always include a **SOURCES** section at the end of your response, referencing the documents used.

Example response format:
**Answer:**
<your answer here>

**SOURCES:**
Source 1, Source 2, etc.

Begin!
----------------
{summaries}
"""

# Shared vector-store handle; on_chat_start fills it in on first use.
qdrant_vectorstore = None

# Chat prompt = system instructions + the user's question.
# NOTE(review): `prompt` / `chain_type_kwargs` are not passed to the chain built in
# generate_answer, and the system template uses a {summaries} slot rather than the
# {context} slot ConversationalRetrievalChain's "stuff" step fills — confirm before wiring in.
messages = [SystemMessagePromptTemplate.from_template(system_template)]
messages.append(HumanMessagePromptTemplate.from_template("{question}"))
prompt = ChatPromptTemplate.from_messages(messages)
chain_type_kwargs = dict(prompt=prompt)


In [11]:
# Cell 3: Functions for PubMed Search and Related Questions

def _format_pubmed_record(rec):
    """Format one parsed MEDLINE record as a markdown citation, linked to PubMed when it has a PMID."""
    # MEDLINE titles usually end with "."; strip it so joining with ". " doesn't yield "..".
    title = rec.get("TI", "").strip().rstrip(".")
    authors = ", ".join(rec.get("AU", []))
    source = rec.get("SO", "")
    # Skip empty parts so a record without authors doesn't start with ". ".
    citation = ". ".join(part for part in (authors, title, source) if part)
    # MEDLINE encloses translated (non-English) titles in [square brackets]; escape them
    # so they don't terminate the markdown link text early.
    citation = citation.replace("[", "\\[").replace("]", "\\]")
    pmid = rec.get("PMID")
    if not pmid:
        return citation
    return f"[{citation}](https://pubmed.ncbi.nlm.nih.gov/{pmid}/)"


def _search_pubmed_blocking(query, max_results):
    """Synchronously run the PubMed esearch/efetch round-trip; returns a (possibly empty) list of citations."""
    Entrez.email = os.environ.get("ENTREZ_EMAIL")  # NCBI asks API users to identify themselves
    search_handle = Entrez.esearch(db="pubmed", term=query, retmax=max_results)
    try:
        record = Entrez.read(search_handle)
    finally:
        search_handle.close()

    id_list = record["IdList"]
    if not id_list:
        return []

    fetch_handle = Entrez.efetch(db="pubmed", id=id_list, rettype="medline", retmode="text")
    try:
        # Medline.parse is lazy: the handle must stay open until iteration completes.
        return [_format_pubmed_record(rec) for rec in Medline.parse(fetch_handle)]
    finally:
        fetch_handle.close()


async def search_related_papers(query, max_results=3):
    """
    Search PubMed for papers related to the provided query and return a list of formatted references.

    Args:
        query: Free-text PubMed search term.
        max_results: Maximum number of PubMed IDs to fetch.

    Returns:
        A list of markdown strings ``[citation](url)``. If nothing is found, a single
        "No directly related papers found..." message; on any error, a single
        error message (the exception is printed, never raised).
    """
    try:
        # Entrez calls are blocking network I/O; run them in the default executor so
        # they don't stall the event loop serving other chat sessions.
        loop = asyncio.get_running_loop()
        related_papers = await loop.run_in_executor(None, _search_pubmed_blocking, query, max_results)
    except Exception as e:
        print(f"Error occurred while searching for related papers: {e}")
        return ["An error occurred while searching for related papers. Please try again later."]
    if not related_papers:
        return ["No directly related papers found. Try broadening your search query."]
    return related_papers

async def generate_related_questions(retrieved_results, num_questions=2, max_tokens=50):
    """
    Generate follow-up questions grounded in the retrieved document context.

    Args:
        retrieved_results: Retrieved documents (objects with ``page_content``).
        num_questions: Number of questions to ask the LLM for.
        max_tokens: Completion-length cap, applied to the LLM itself.

    Returns:
        A list of non-empty question strings with any leading list marker
        ("1.", "2)", "-", "*") removed. Empty when there is no context.
    """
    if not retrieved_results:
        # Nothing to ground questions in; skip the LLM call entirely.
        return []

    # max_tokens must be configured on the model: passing it to chain.run() only adds an
    # unused chain input, so the cap was previously never applied.
    llm = OpenAI(temperature=0.7, max_tokens=max_tokens)
    prompt_template = PromptTemplate(
        input_variables=["context", "num_questions"],
        template="Given the following context, generate {num_questions} related questions:\n\nContext: {context}\n\nQuestions:",
    )
    chain = LLMChain(llm=llm, prompt=prompt_template)
    context = " ".join(doc.page_content for doc in retrieved_results)
    # Async API so the LLM round-trip doesn't block the event loop.
    generated_questions = await chain.arun(context=context, num_questions=num_questions)

    related_questions = []
    for line in generated_questions.splitlines():
        # Strip only a leading list marker; the old split(". ", 1) also chopped the first
        # sentence off unnumbered questions and left "1)"-style numbering in place.
        question = re.sub(r"^\s*(?:\d+[.)]|[-*])\s*", "", line).strip()
        if question:
            related_questions.append(question)
    return related_questions


In [12]:
# Cell 4: Function to Generate Answer

def _build_source_actions(source_docs):
    """One "Source N" button per retrieved document; the payload carries the chunk text."""
    return [
        cl.Action(
            name="show_source",
            label=f"Source {i + 1}",
            payload={"source_content": doc.page_content},
        )
        for i, doc in enumerate(source_docs or [])
    ]


def _build_related_question_actions(questions):
    """One "related_question" button per non-blank generated question."""
    actions = []
    for q in questions:
        text = q.strip()
        if text:
            actions.append(cl.Action(name="related_question", label=text, payload={"question": text}))
    return actions


async def generate_answer(query):
    """
    Generate an answer using a conversational retrieval chain over the Qdrant vector store.

    Returns a tuple:
      (formatted_answer, text_elements, related_question_actions, related_papers, source_actions, original_query)

    ``text_elements`` is always an empty list (kept for caller compatibility). When the
    model says it doesn't know, the raw answer is returned with all extras empty. Any
    error is printed and reported through a generic message in the first slot.
    """
    try:
        # Everything, including chain construction, is inside the try so that a missing
        # vector store (on_chat_start not yet run / failed) yields the friendly error
        # tuple instead of an unhandled AttributeError in the action callbacks.
        if qdrant_vectorstore is None:
            raise RuntimeError("Qdrant vector store is not initialised (on_chat_start has not completed).")

        # NOTE(review): a fresh history is created on every call, so the chain never sees
        # earlier turns; persist it per session if multi-turn context is intended.
        message_history = ChatMessageHistory()
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            chat_memory=message_history,
            return_messages=True,
        )
        chain = ConversationalRetrievalChain.from_llm(
            ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True),
            chain_type="stuff",
            retriever=qdrant_vectorstore.as_retriever(),
            memory=memory,
            return_source_documents=True,
        )

        cb = cl.AsyncLangchainCallbackHandler()
        feedback_callback = EvaluatorCallbackHandler(
            evaluators=[
                PharmAssistEvaluator(),
                HarmfulnessEvaluator(),
                AIDetectionEvaluator(),
            ]
        )
        res = await chain.acall(query, callbacks=[cb, feedback_callback])
        # Remove triple backticks to avoid code blocks (thus no copy button)
        answer = res["answer"].replace("```", "")
        source_docs = res["source_documents"]

        # Normalise curly apostrophes so "I don’t know" is detected too.
        normalized = answer.lower().replace("\u2019", "'")
        if normalized.startswith(("i don't know", "i don't have enough information")):
            return answer, [], [], [], [], query

        formatted_answer = f"**Answer:**\n{answer}"
        source_actions = _build_source_actions(source_docs)

        # The two enrichment steps are independent: run them concurrently, and never let
        # a failure in either one throw away an otherwise good answer.
        related_questions, related_papers = await asyncio.gather(
            generate_related_questions(source_docs),
            search_related_papers(query),
            return_exceptions=True,
        )
        if isinstance(related_questions, BaseException):
            print(f"Error occurred while generating related questions: {related_questions}")
            related_questions = []
        if isinstance(related_papers, BaseException):
            print(f"Error occurred while searching for related papers: {related_papers}")
            related_papers = []

        related_question_actions = _build_related_question_actions(related_questions)
        return formatted_answer, [], related_question_actions, related_papers, source_actions, query

    except Exception as e:
        print(f"Error occurred: {e}")
        return (
            "An error occurred while processing your request. Please try again later.",
            [],
            [],
            [],
            [],
            query,
        )


In [13]:
# Cell 5: Action Callbacks

@cl.action_callback("show_source")
async def on_show_source(action: cl.Action):
    """
    Show the full text of a retrieved source chunk when its "Source N" button is clicked.
    Triple backticks are stripped and the Text element is non-copyable, so no code
    block / copy button is rendered.
    """
    raw_content = action.payload["source_content"]
    details = "**Source Details:**\n\n" + raw_content.replace("```", "")
    await cl.Message(
        elements=[cl.Text(content=details, copyable=False, markdown=True)],
        author="PrescriptionIQ",
    ).send()

async def _answer_and_render(question):
    """
    Echo ``question`` as a user message, answer it, and post the answer and its extras.

    Shared by the "related_question" and "ask_question" action callbacks, which were
    previously line-for-line duplicates. Errors are reported to the user the same way
    the free-text on_message handler reports them, instead of escaping the callback.
    """
    user_elem = cl.Text(content=question, copyable=False, markdown=True)
    await cl.Message(elements=[user_elem], author="User").send()

    try:
        ans, _txt_elems, rel_q_actions, rel_papers, src_actions, _orig_q = await generate_answer(question)
        answer_elem = cl.Text(content=ans, copyable=False, markdown=True)
        await cl.Message(elements=[answer_elem], author="PrescriptionIQ").send()

        if rel_q_actions:
            rq_elem = cl.Text(content="**Related Questions:**", copyable=False, markdown=True)
            await cl.Message(elements=[rq_elem], actions=rel_q_actions, author="PrescriptionIQ").send()

        if rel_papers:
            papers_md = "**Related Papers from PubMed:**\n" + "\n".join(f"- {p}" for p in rel_papers)
            pm_elem = cl.Text(content=papers_md, copyable=False, markdown=True)
            await cl.Message(elements=[pm_elem], author="PrescriptionIQ").send()

        if src_actions:
            src_elem = cl.Text(content="**Sources:** Click to expand.", copyable=False, markdown=True)
            await cl.Message(elements=[src_elem], actions=src_actions, author="PrescriptionIQ").send()

    except Exception as e:
        print(f"Error occurred: {e}")
        error_elem = cl.Text(content="An error occurred while processing your request. Please try again later.", copyable=False, markdown=True)
        await cl.Message(elements=[error_elem], author="PrescriptionIQ").send()


@cl.action_callback("related_question")
async def on_related_question_selected(action: cl.Action):
    """
    Handle a related question selection (a button produced alongside a previous answer).
    """
    await _answer_and_render(action.payload["question"])


@cl.action_callback("ask_question")
async def on_question_selected(action: cl.Action):
    """
    Respond to a selected question from the suggestions shown at chat start.
    """
    await _answer_and_render(action.payload["question"])


In [14]:
# Cell 6: on_chat_start Callback

# openFDA drug-label shard used to seed the vector store when the collection is missing.
_FDA_LABEL_URL = "https://download.open.fda.gov/drug/label/drug-label-0001-of-0012.json.zip"
_FDA_COLLECTION = "fda_drugs"

# openFDA fields copied into each Document's metadata.
_FDA_METADATA_FIELDS = [
    "openfda.brand_name",
    "openfda.generic_name",
    "openfda.manufacturer_name",
    "openfda.product_type",
    "openfda.route",
    "openfda.substance_name",
    "openfda.rxcui",
    "openfda.spl_id",
    "openfda.package_ndc",
]
# openFDA label sections concatenated into each Document's page content.
_FDA_TEXT_FIELDS = [
    "description",
    "indications_and_usage",
    "contraindications",
    "warnings",
    "adverse_reactions",
    "dosage_and_administration",
]


def _flatten_fda_value(value, sep, missing):
    """Collapse one openFDA value: lists are joined with ``sep``, absent/NaN becomes ``missing``.

    openFDA label fields are JSON arrays of strings, which json_normalize leaves as Python
    lists; fields absent from a record come through as NaN (or None via dict.get).
    """
    if isinstance(value, list):
        return sep.join(str(v) for v in value if pd.notna(v))
    if pd.isna(value):
        return missing
    return value


def _load_fda_label_documents():
    """Download the openFDA label shard and return it split into <=1000-char Document chunks.

    Blocking (network download, JSON parsing, pandas); run it off the event loop.
    """
    resp = requests.get(_FDA_LABEL_URL, timeout=120)
    resp.raise_for_status()
    # Close the archive and member stream deterministically.
    with zipfile.ZipFile(io.BytesIO(resp.content)) as zip_file:
        with zip_file.open(zip_file.namelist()[0]) as json_file:
            data = json.load(json_file)
    df = pd.json_normalize(data["results"])

    drug_docs = []
    for row in df.to_dict(orient="records"):
        # Join list-valued sections into plain text. The previous astype(str) embedded
        # "['...']" list reprs into the searchable content, and df[text_fields] raised
        # KeyError if a section column was absent from the shard.
        content = " ".join(str(_flatten_fda_value(row.get(f), " ", "")) for f in _FDA_TEXT_FIELDS)
        metadata = {f: _flatten_fda_value(row.get(f), ", ", "Not Available") for f in _FDA_METADATA_FIELDS}
        drug_docs.append(Document(page_content=content, metadata=metadata))

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    return text_splitter.split_documents(drug_docs)


async def _init_fda_vectorstore():
    """Return a Qdrant store for the "fda_drugs" collection, creating and populating it if missing."""
    embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
    qdrant_api_key = os.environ.get("QDRANT_API_KEY")
    qdrant_url = os.environ.get("QDRANT_CLUSTER_URL")
    async_client = AsyncQdrantClient(url=qdrant_url, api_key=qdrant_api_key, timeout=60)
    response = await async_client.get_collections()
    collection_names = {c.name for c in response.collections}

    if _FDA_COLLECTION not in collection_names:
        print("Collection 'fda_drugs' is not present.")
        split_drug_docs = await cl.make_async(_load_fda_label_documents)()
        return await cl.make_async(Qdrant.from_documents)(
            split_drug_docs,
            embedding_model,
            url=qdrant_url,
            api_key=qdrant_api_key,
            collection_name=_FDA_COLLECTION,
        )

    print("Collection 'fda_drugs' is present.")
    # Attach to the existing collection directly. construct_instance(texts=[""]) embedded
    # an empty string (an OpenAI API call) only to learn the vector size.
    return Qdrant(
        client=QdrantClient(url=qdrant_url, api_key=qdrant_api_key, timeout=60),
        async_client=async_client,
        collection_name=_FDA_COLLECTION,
        embeddings=embedding_model,
    )


@cl.on_chat_start
async def on_chat_start():
    """
    Initialize the chatbot environment, load Qdrant data, and present initial suggested questions.

    The vector store lives in the module-global ``qdrant_vectorstore``; it is only built
    when still None, so later sessions in the same process reuse it.
    NOTE(review): two sessions starting at the same moment can both observe None and
    both initialise the store.
    """
    global qdrant_vectorstore
    loading_elem = cl.Text(content="**Loading PrescriptionIQ bot...**", copyable=False, markdown=True)
    await cl.Message(elements=[loading_elem]).send()
    await asyncio.sleep(2)  # fixed pause after the loading message; nothing depends on it

    if qdrant_vectorstore is None:
        qdrant_vectorstore = await _init_fda_vectorstore()

    potential_questions = [
        "What should I be careful of when taking Metformin?",
        "What are the contraindications of Aspirin?",
        "Are there low-cost alternatives to branded Aspirin available over-the-counter?",
        "What precautions should I take if I'm pregnant or nursing while on Lipitor?",
        "Should Lipitor be taken at a specific time of day, and does it need to be taken with food?",
        "What is the recommended dose of Aspirin?",
        "Can older people take beta blockers?",
        "How do beta blockers work?",
        "Can beta blockers be used for anxiety?",
        "I am taking Aspirin, is it ok to take Glipizide?",
        "Explain in simple terms how Metformin works?",
    ]
    welcome_elem = cl.Text(content="**Welcome to PrescriptionIQ!** Here are some potential questions you can ask:", copyable=False, markdown=True)
    await cl.Message(
        elements=[welcome_elem],
        actions=[cl.Action(name="ask_question", label=q, payload={"question": q}) for q in potential_questions],
    ).send()
    cl.user_session.set("potential_questions_shown", True)


In [28]:
# Cell 7: on_message Callback

@cl.on_message
async def main(message):
    """
    Answer a free-text chat message.

    The answer (followed by an echo of the question) is posted first as its own bot
    message. Related-question buttons, PubMed papers and source buttons follow as
    separate messages when present. Blank input gets a retry prompt, and any failure
    is reported with a generic error message.
    """

    async def reply(content, actions=None):
        # Every bot message is a single non-copyable markdown Text element.
        element = cl.Text(content=content, copyable=False, markdown=True)
        if actions is None:
            await cl.Message(elements=[element], author="PrescriptionIQ").send()
        else:
            await cl.Message(elements=[element], actions=actions, author="PrescriptionIQ").send()

    query = message.content.strip()
    if query == "":
        await reply("Please enter a valid question.")
        return

    try:
        (answer, _text_elems, question_actions, papers,
         source_actions, asked) = await generate_answer(query)
        answer = answer or "Sorry, I couldn't generate an answer. Please try rephrasing your question."
        cleaned = answer.replace("```", "")
        await reply(f"{cleaned}\n\n---\n*Your question: {asked}*")

        if question_actions:
            await reply("**Related Questions:**", question_actions)

        if papers:
            bullet_list = "\n".join(f"- {p}" for p in papers)
            await reply("**Related Papers from PubMed:**\n" + bullet_list)

        if source_actions:
            await reply("**Sources:** Click to expand.", source_actions)

    except Exception as e:
        print(f"Error occurred: {e}")
        await reply("An error occurred while processing your request. Please try again later.")



In [30]:
!chainlit run PrescriptionIQ.py -w
# chainlit run app.py -p 5000


2025-03-22 23:16:02 - 1 change detected
2025-03-22 23:16:03 - 1 change detected
^C


In [27]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from mpl_toolkits.mplot3d import Axes3D  # This import registers the 3D projection, no need to call it directly.


In [21]:
# Placeholder data for the PCA plots below: random vectors standing in for real
# document embeddings. Replace with the actual embedding vectors.
SEED = 42  # fixed seed so Restart & Run All reproduces the same plots
n_samples = 100
# 1536 is the output size of OpenAI text-embedding-3-small, the model used in on_chat_start.
embedding_dim = 1536
rng = np.random.default_rng(SEED)
embeddings = rng.standard_normal((n_samples, embedding_dim))


In [None]:
# Project the embeddings onto their first two principal components.
pca2d = PCA(n_components=2)
embeddings_2d = pca2d.fit_transform(embeddings)

# Flat scatter plot of the projection, using the explicit Figure/Axes interface.
fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(*embeddings_2d.T, alpha=0.7)
ax.set_title("2D Visualization of Embeddings (PCA)")
ax.set_xlabel("Principal Component 1")
ax.set_ylabel("Principal Component 2")
ax.grid(True)
plt.show()


In [None]:
# Project the embeddings onto their first three principal components.
pca3d = PCA(n_components=3)
embeddings_3d = pca3d.fit_transform(embeddings)

# 3D scatter plot; the "3d" projection is registered by the mpl_toolkits.mplot3d import.
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")
ax.scatter(*embeddings_3d.T, alpha=0.7)
ax.set(
    title="3D Visualization of Embeddings (PCA)",
    xlabel="Principal Component 1",
    ylabel="Principal Component 2",
    zlabel="Principal Component 3",
)
plt.show()
