In [71]:
!pip install google-generativeai langchain-google-genai langchain-core langchain-community langchain-text-splitters chromadb pymupdf streamlit pyngrok



In [97]:
%%writefile app.py
import html
import os
import tempfile
import uuid

import fitz
import google.generativeai as genai
import streamlit as st
import streamlit.components.v1 as components
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_text_splitters import RecursiveCharacterTextSplitter


class RAGPipeline:
    """Question answering, diagram generation and summarisation over one PDF.

    The pipeline extracts the PDF text with PyMuPDF, splits it into overlapping
    chunks, indexes the chunks in a Chroma vector store using Gemini embeddings
    and uses a Gemini chat model to answer questions or produce Mermaid diagrams
    and summaries.

    Attributes:
        embeddings: Embedding model used to index and query chunks.
        llm: Chat model used for every generation task.
        vector_store: Chroma store of the current document, or None.
        retriever: Similarity retriever over ``vector_store``, or None.
        full_text: Raw text of the most recently ingested PDF.
    """

    # Chunking parameters handed to RecursiveCharacterTextSplitter.
    CHUNK_SIZE = 500
    CHUNK_OVERLAP = 100
    # Number of chunks fetched for each retrieval.
    RETRIEVER_K = 8
    # Extracted text shorter than this is reported as a failed extraction.
    MIN_TEXT_CHARS = 100
    # Character budgets for prompts that embed the raw document text.
    DIAGRAM_CONTEXT_CHARS = 8000
    SUMMARY_CONTEXT_CHARS = 10000

    def __init__(self, api_key):
        """Configure the Gemini clients with ``api_key``.

        Note that the key is also exported as the ``GOOGLE_API_KEY`` environment
        variable, which is a process-wide side effect.
        """
        os.environ["GOOGLE_API_KEY"] = api_key
        genai.configure(api_key=api_key)
        self.embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
        self.llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", temperature=0.3)
        self.vector_store = None
        self.retriever = None
        self.full_text = ""

    def ingest_pdf(self, pdf_path):
        """Extract, chunk and index the PDF at ``pdf_path``.

        Returns:
            A ``(chunk_count, status)`` tuple. ``status`` is ``"Success"`` when
            the document was indexed; otherwise ``chunk_count`` is 0 and
            ``status`` is an ``"Error: ..."`` message.
        """
        # Drop the previous index up front so that a failed ingest cannot leave
        # the pipeline answering questions from an older document.
        self.vector_store = None
        self.retriever = None
        self.full_text = ""

        # The context manager closes the document even if a page fails to parse.
        with fitz.open(pdf_path) as doc:
            self.full_text = "".join(page.get_text() for page in doc)

        # Strip before measuring: image-only PDFs can yield nothing but whitespace.
        if len(self.full_text.strip()) < self.MIN_TEXT_CHARS:
            return 0, "Error: Could not extract text from PDF"

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.CHUNK_SIZE,
            chunk_overlap=self.CHUNK_OVERLAP,
            separators=["\n\n", "\n", ". ", " ", ""]
        )
        chunks = text_splitter.split_text(self.full_text)

        if not chunks:
            return 0, "Error: No chunks created"

        # A unique collection per ingest keeps documents apart. With the default
        # collection name, in-memory Chroma stores created in the same process
        # can end up sharing one collection, mixing chunks of different PDFs.
        self.vector_store = Chroma.from_texts(
            texts=chunks,
            embedding=self.embeddings,
            collection_name=f"pdf-{uuid.uuid4().hex}",
        )

        self.retriever = self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": self.RETRIEVER_K}
        )
        return len(chunks), "Success"

    def format_docs(self, docs):
        """Join the ``page_content`` of ``docs`` into one context string.

        Returns a fixed placeholder sentence when ``docs`` is empty.
        """
        if not docs:
            return "No relevant content found."
        return "\n\n---\n\n".join(doc.page_content for doc in docs)

    def _diagram_context(self, topic):
        """Return retrieved chunks for ``topic``, or the head of the document."""
        if topic:
            return self.format_docs(self.retriever.invoke(topic))
        return self.full_text[:self.DIAGRAM_CONTEXT_CHARS]  # Limit context size

    @staticmethod
    def _strip_code_fences(text):
        """Remove Markdown code fences the model may wrap around Mermaid code."""
        return text.strip().replace("```mermaid", "").replace("```", "").strip()

    def answer_query(self, query):
        """Answer ``query`` using only chunks retrieved from the document.

        Returns the model's answer, or a plain-text notice when no document has
        been ingested or nothing relevant was retrieved.
        """
        if not self.retriever:
            return "Please upload a document first."

        docs = self.retriever.invoke(query)

        # Bail out before building a prompt when there is nothing to ground on.
        if len(docs) == 0:
            return "No relevant content found in the document for your query."

        context = self.format_docs(docs)

        prompt = ChatPromptTemplate.from_template("""
        Based on the following context from a document, answer the question.
        Use only the information provided in the context.
        If the information is not available, say so clearly.

        Context:
        {context}

        Question: {question}

        Answer:
        """)

        chain = prompt | self.llm | StrOutputParser()

        return chain.invoke({
            "context": context,
            "question": query
        })

    def generate_mindmap(self, topic=None):
        """Generate Mermaid mind map code for the document or a specific topic.

        With a ``topic`` the prompt is grounded on retrieved chunks; without one
        it uses the first ``DIAGRAM_CONTEXT_CHARS`` characters of the document.
        Returns the Mermaid code with any Markdown fences removed, or a notice
        string when no document has been ingested.
        """
        if not self.retriever:
            return "Please upload a document first."

        context = self._diagram_context(topic)
        if topic:
            task = f"Create a mind map about: {topic}"
        else:
            task = "Create a comprehensive mind map of the entire document"

        prompt = f"""
        Based on the text below, generate a Mermaid JS mind map code.

        STRICT RULES:
        1. Use 'mindmap' as the diagram type
        2. Use proper indentation for hierarchy
        3. Keep labels short and clear
        4. Return ONLY the raw mermaid code, no markdown wrapper
        5. Do not use special characters like brackets or parentheses in labels

        Format example:
        mindmap
          root((Main Topic))
            Branch1
              Sub1
              Sub2
            Branch2
              Sub3
              Sub4

        Context: {context}

        Task: {task}
        """

        response = self.llm.invoke(prompt)
        return self._strip_code_fences(response.content)

    def generate_flowchart(self, topic=None):
        """Generate Mermaid flowchart code for the document or a specific topic.

        Context selection and return value follow the same rules as
        ``generate_mindmap``.
        """
        if not self.retriever:
            return "Please upload a document first."

        context = self._diagram_context(topic)
        if topic:
            task = f"Create a flowchart about: {topic}"
        else:
            task = "Create a flowchart showing the main process or structure in the document"

        prompt = f"""
        Based on the text below, generate a Mermaid JS flowchart code (graph TD).

        STRICT RULES:
        1. Use graph TD (top-down) format
        2. Enclose ALL node labels in double quotes: id["Label"]
        3. Keep labels concise
        4. Return ONLY raw mermaid code, no markdown
        5. Do not use special characters inside quotes

        Format example:
        graph TD
            A["Start"] --> B["Step 1"]
            B --> C["Step 2"]
            C --> D["End"]

        Context: {context}

        Task: {task}
        """

        response = self.llm.invoke(prompt)
        return self._strip_code_fences(response.content)

    def generate_summary(self):
        """Summarise the first ``SUMMARY_CONTEXT_CHARS`` characters of the document.

        Returns the model's summary, or a notice string when no text is loaded.
        """
        if not self.full_text:
            return "Please upload a document first."

        prompt = f"""
        Provide a comprehensive summary of the following document.
        Include:
        1. Main topic/purpose
        2. Key points
        3. Important findings or conclusions

        Document:
        {self.full_text[:self.SUMMARY_CONTEXT_CHARS]}

        Summary:
        """

        response = self.llm.invoke(prompt)
        return response.content


def render_mermaid(code, height=500):
    """Render Mermaid diagram source inside a Streamlit HTML component.

    Args:
        code: Mermaid diagram source. It is model-generated text, so it is
            treated as untrusted.
        height: Height of the embedded component in pixels.
    """
    # Escape the diagram source before embedding it in markup. Otherwise a "<"
    # or "&" in a label would be parsed as HTML, which both breaks the diagram
    # and lets document or model text inject markup into the page. Mermaid
    # decodes HTML entities when it reads the element, so the diagram text is
    # unchanged from its point of view.
    safe_code = html.escape(code, quote=False)

    # NOTE(review): securityLevel 'loose' lets Mermaid keep HTML inside labels;
    # consider 'strict' if diagrams never need it.
    html_code = f"""
    <div class="mermaid" style="background-color: white; padding: 20px; border-radius: 10px;">
        {safe_code}
    </div>
    <script type="module">
        import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
        mermaid.initialize({{
            startOnLoad: true,
            theme: 'default',
            securityLevel: 'loose'
        }});
    </script>
    """
    components.html(html_code, height=height, scrolling=True)


# STREAMLIT UI

# The page emoji (U+1F4C4, "page facing up") is written as an escape so that it
# survives any re-encoding of this file; the literal character had previously
# been corrupted into mojibake.
PAGE_ICON = "\U0001F4C4"

st.set_page_config(
    page_title="AI Research Assistant",
    page_icon=PAGE_ICON,
    layout="wide"
)

# Custom CSS: spacing and font size of the tab bar, plus the page header style.
st.markdown("""
<style>
    .stTabs [data-baseweb="tab-list"] {
        gap: 24px !important;
    }
    .stTabs [data-baseweb="tab"] {
        height: 40px !important;
        padding-left: 20px;
        padding-right: 20px;
        font-size: 30px;
    }
    .main-header {
        font-size: 50px !important;
        font-weight: bold;
        color: #1E88E5;
        margin-bottom: 20px;
    }
</style>
""", unsafe_allow_html=True)

st.markdown(f'<p class="main-header">{PAGE_ICON} AI Research Assistant</p>', unsafe_allow_html=True)


# SIDEBAR
# Sidebar: API key entry, PDF upload and the "Process Document" action that
# builds the RAG pipeline stored in st.session_state.rag.
with st.sidebar:
    st.header("Configuration")

    api_key = st.text_input("Google API Key", type="password", help="Enter your Google Gemini API key")

    st.divider()

    st.header("Upload Document")
    uploaded_file = st.file_uploader(
        "Upload Research Paper (PDF)",
        type="pdf",
        help="Upload a PDF document to analyze"
    )

    if st.button("Process Document", use_container_width=True):
        if not uploaded_file:
            st.error("Please upload a PDF file")
        elif not api_key:
            st.error("Please enter your API key")
        else:
            with st.spinner("Processing document..."):
                tmp_path = None
                try:
                    # PyMuPDF opens the document from a path, so the upload is
                    # spooled to a temporary file first.
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
                        tmp_file.write(uploaded_file.getvalue())
                        tmp_path = tmp_file.name

                    pipeline = RAGPipeline(api_key)
                    num_chunks, status = pipeline.ingest_pdf(tmp_path)

                    if status == "Success":
                        # Publish the pipeline only once it is usable, so a
                        # failed upload never replaces a working one.
                        st.session_state.rag = pipeline
                        st.session_state.doc_processed = True
                        # Chat history, diagram and summary belong to the
                        # previous document; drop them.
                        for stale_key in ("messages", "current_diagram", "diagram_type", "doc_summary"):
                            if stale_key in st.session_state:
                                del st.session_state[stale_key]
                        st.success(f"Indexed {num_chunks} chunks!")
                    else:
                        st.error(status)
                except Exception as e:
                    st.error(f"Error: {str(e)}")
                finally:
                    # Remove the temporary copy even when ingestion raised.
                    if tmp_path and os.path.exists(tmp_path):
                        os.unlink(tmp_path)

    if "doc_processed" in st.session_state and st.session_state.doc_processed:
        st.success("Document Ready!")

    st.divider()
    st.markdown("### Instructions")
    st.markdown("""
    1. Enter your Google API key
    2. Upload a PDF document
    3. Click 'Process Document'
    4. Use Q&A or Mind Map tabs
    """)


# MAIN TABS
# Top-level navigation: chat, diagram generation and summary, in that order.
TAB_LABELS = ["Q&A Chat", "Mind Map Generator", "Document Summary"]
tab1, tab2, tab3 = st.tabs(TAB_LABELS)



# TAB 1: Q&A CHAT
# Chat tab: replays the stored conversation, then answers a new question with
# the RAG pipeline and appends both turns to st.session_state.messages.
with tab1:
    st.subheader(" Ask Questions About Your Document")

    if "rag" not in st.session_state:
        st.info("Please upload and process a document first using the sidebar.")
    else:
        if "messages" not in st.session_state:
            st.session_state.messages = []

        chat_container = st.container()
        with chat_container:
            for message in st.session_state.messages:
                with st.chat_message(message["role"]):
                    st.markdown(message["content"])

        if prompt := st.chat_input("Ask a question about the document..."):
            st.session_state.messages.append({"role": "user", "content": prompt})

            # Write the new turn into the history container so it appears in
            # the same place now as it will when replayed on the next rerun.
            with chat_container:
                with st.chat_message("user"):
                    st.markdown(prompt)

                with st.chat_message("assistant"):
                    with st.spinner("Thinking..."):
                        try:
                            response = st.session_state.rag.answer_query(prompt)
                            st.markdown(response)
                            st.session_state.messages.append({"role": "assistant", "content": response})
                        except Exception as e:
                            st.error(f"Error: {str(e)}")

        # Only the first column is used; the others keep the button narrow.
        clear_col, _, _ = st.columns([1, 1, 2])
        with clear_col:
            if st.button("Clear Chat", use_container_width=True):
                st.session_state.messages = []
                st.rerun()


# TAB 2: MIND MAP GENERATOR
def _store_diagram(kind, topic):
    """Generate a diagram of ``kind`` and keep it in the session state.

    ``kind`` is "Mind Map" or "Flowchart"; ``topic`` is a topic string or None
    for a whole-document overview. Exceptions from the pipeline propagate.
    """
    pipeline = st.session_state.rag
    if kind == "Mind Map":
        code = pipeline.generate_mindmap(topic)
    else:
        code = pipeline.generate_flowchart(topic)
    st.session_state.current_diagram = code
    st.session_state.diagram_type = kind


# Diagram tab: generates a Mermaid mind map or flowchart, shows the latest one
# (rendered and as source), and offers one-click presets.
with tab2:
    st.subheader("Generate Visual Diagrams")

    if "rag" not in st.session_state:
        st.info("Please upload and process a document first using the sidebar.")
    else:
        col1, col2 = st.columns(2)

        with col1:
            diagram_type = st.selectbox(
                "Select Diagram Type",
                ["Mind Map", "Flowchart"],
                help="Choose the type of diagram to generate"
            )

        with col2:
            topic_input = st.text_input(
                "Specific Topic (Optional)",
                placeholder="Leave empty for full document overview",
                help="Enter a specific topic or leave empty for document overview"
            )

        if st.button("Generate Diagram", use_container_width=True, type="primary"):
            with st.spinner(f"Generating {diagram_type}..."):
                try:
                    # A blank topic means "overview of the whole document".
                    _store_diagram(diagram_type, topic_input.strip() or None)
                except Exception as e:
                    st.error(f"Error generating diagram: {str(e)}")

        st.divider()

        if "current_diagram" in st.session_state:
            st.subheader(f"Generated {st.session_state.get('diagram_type', 'Diagram')}")

            view_tab1, view_tab2 = st.tabs(["Visual View", "Code View"])

            with view_tab1:
                try:
                    render_mermaid(st.session_state.current_diagram, height=600)
                except Exception as e:
                    st.error(f"Error rendering diagram: {str(e)}")
                    st.code(st.session_state.current_diagram, language="mermaid")

            with view_tab2:
                st.code(st.session_state.current_diagram, language="mermaid")
                # The previous "Copy Code" button announced a copy that never
                # happened; point at the code block's own copy control instead.
                st.caption("Use the copy icon on the code block above to copy the Mermaid source.")

        st.divider()
        st.subheader("Quick Generate")

        # (button label, diagram kind, retrieval topic or None for an overview)
        quick_actions = [
            ("Document Overview", "Mind Map", None),
            ("Process Flow", "Flowchart", "main process or methodology"),
            ("Key Concepts", "Mind Map", "key concepts and findings"),
        ]

        for quick_col, (label, kind, topic) in zip(st.columns(len(quick_actions)), quick_actions):
            with quick_col:
                if st.button(label, use_container_width=True):
                    generated = False
                    with st.spinner("Generating..."):
                        try:
                            _store_diagram(kind, topic)
                            generated = True
                        except Exception as e:
                            st.error(str(e))
                    # These buttons sit below the viewer, so a rerun is needed
                    # to show the new diagram. It is issued outside the try
                    # block so the rerun signal can never be mistaken for a
                    # generation error.
                    if generated:
                        st.rerun()


# TAB 3: DOCUMENT SUMMARY
# Summary tab: generates and caches a document summary, then offers canned
# follow-up questions answered through the RAG pipeline.
with tab3:
    st.subheader("Document Summary")

    if "rag" not in st.session_state:
        st.info("Please upload and process a document first using the sidebar.")
    else:
        if st.button("Generate Summary", use_container_width=True, type="primary"):
            with st.spinner("Analyzing document..."):
                try:
                    summary = st.session_state.rag.generate_summary()
                    st.session_state.doc_summary = summary
                except Exception as e:
                    st.error(f"Error: {str(e)}")

        if "doc_summary" in st.session_state:
            st.markdown("### Summary")
            st.markdown(st.session_state.doc_summary)

            st.divider()

            # Additional analysis buttons
            st.subheader("Quick Analysis")

            # (button label and heading, spinner text, question sent to the pipeline)
            quick_analyses = [
                (
                    "Key Takeaways",
                    "Extracting...",
                    "What are the main takeaways and conclusions from this document? List them as bullet points.",
                ),
                (
                    "Research Questions",
                    "Analyzing...",
                    "What research questions or problems does this document address?",
                ),
            ]

            for analysis_col, (title, spinner_text, question) in zip(
                st.columns(len(quick_analyses)), quick_analyses
            ):
                with analysis_col:
                    if st.button(title, use_container_width=True):
                        with st.spinner(spinner_text):
                            # Report API failures the same way as the other
                            # tabs instead of surfacing a raw traceback.
                            try:
                                response = st.session_state.rag.answer_query(question)
                                st.markdown(f"### {title}")
                                st.markdown(response)
                            except Exception as e:
                                st.error(f"Error: {str(e)}")


st.divider()

Overwriting app.py


In [98]:
from pyngrok import ngrok
import time

ngrok.kill()

Token = "38ZrdgoQXhHzzxUt2otTMSs1gUM_5z9D9pVHqJqoFxYE5LzUC"
ngrok.set_auth_token(Token)

!streamlit run app.py &>/content/logs.txt &

time.sleep(5)

public_url = ngrok.connect(8501)
print("ðŸš€ App running at:", public_url)

🚀 App running at: NgrokTunnel: "https://sclerosed-violet-germless.ngrok-free.dev" -> "http://localhost:8501"
