In [None]:
from langchain.schema import SystemMessage
from typing import Type
from pydantic import BaseModel
from pydantic import Field
from langchain.chat_models import ChatOpenAI
from langchain.tools import BaseTool
from langchain.agents import initialize_agent
from langchain.agents import AgentType
from langchain.utilities import DuckDuckGoSearchAPIWrapper
from langchain.utilities import WikipediaAPIWrapper
from langchain.prompts import PromptTemplate
from langchain.document_loaders import WebBaseLoader
from langchain.schema.runnable import RunnablePassthrough
from datetime import datetime
import streamlit as st
import os
import requests

# Chat model instance handed to the agent built in agent_invoke().
llm = ChatOpenAI(temperature=0.1)


class WikipediaSearchTool(BaseTool):
    """Agent tool that looks a query up on Wikipedia via ``WikipediaAPIWrapper``."""

    name: str = "WikipediaSearchTool"
    # The description is what the LLM reads when picking a tool, so it has to
    # say what this tool actually does and set it apart from the web search tool.
    description: str = """
    Use this tool to search Wikipedia for the given query. It returns summaries of the matching Wikipedia pages.
    """

    # Argument schema shown to the agent: a single free-text search query.
    class WikipediaSearchToolArgsSchema(BaseModel):
        query: str = Field(
            description="The query you will search for. Example query: Research about the XZ backdoor",
        )

    args_schema: Type[WikipediaSearchToolArgsSchema] = WikipediaSearchToolArgsSchema

    def _run(self, query):
        """Run the Wikipedia lookup for ``query`` and return the wrapper's result."""
        wikipedia = WikipediaAPIWrapper()
        return wikipedia.run(query)


class DuckDuckGoSearchTool(BaseTool):
    """Agent tool that forwards a query to ``DuckDuckGoSearchAPIWrapper``."""

    name = "DuckDuckGoTool"
    description = """
    Use this tool to find the website for the given query.
    """

    # Argument schema shown to the agent: a single free-text search query.
    class DuckDuckGoSearchToolArgsSchema(BaseModel):
        query: str = Field(
            description="The query you will search for. Example query: Research about the XZ backdoor",
        )

    args_schema: Type[DuckDuckGoSearchToolArgsSchema] = DuckDuckGoSearchToolArgsSchema

    def _run(self, query):
        """Run the DuckDuckGo search for ``query`` and return the wrapper's result."""
        return DuckDuckGoSearchAPIWrapper().run(query)


class LoadWebsiteTool(BaseTool):
    """Agent tool that fetches a single URL with ``WebBaseLoader``."""

    name = "LoadWebsiteTool"
    description = """
    Use this tool to load the website for the given url.
    """

    # Argument schema shown to the agent: the URL of the page to load.
    class LoadWebsiteToolArgsSchema(BaseModel):
        url: str = Field(
            description="The url you will load. Example url: https://en.wikipedia.org/wiki/Backdoor_(computing)",
        )

    args_schema: Type[LoadWebsiteToolArgsSchema] = LoadWebsiteToolArgsSchema

    def _run(self, url):
        """Load ``url`` and return whatever ``WebBaseLoader.load()`` produces."""
        # NOTE(review): this hands the loader's documents back as-is rather
        # than plain text -- confirm that is what the agent should receive.
        return WebBaseLoader([url]).load()


class SaveToFileTool(BaseTool):
    """Agent tool that writes text to a file whose name gets a timestamp prefix."""

    name: str = "SaveToFileTool"
    description: str = """
    Use this tool to save the text to a file.
    """

    # Argument schema shown to the agent: the text and the destination path.
    class SaveToFileToolArgsSchema(BaseModel):
        text: str = Field(
            description="The text you will save to a file.",
        )
        file_path: str = Field(
            description="Path of the file to save the text to.",
        )

    args_schema: Type[SaveToFileToolArgsSchema] = SaveToFileToolArgsSchema

    def _run(self, text, file_path):
        """Write ``text`` to ``file_path`` with a timestamp prefixed to the file name.

        Args:
            text: Content to write (encoded as UTF-8).
            file_path: Destination path; may include a directory part.

        Returns:
            A message naming the path that was actually written.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Prefix only the file name. Prefixing the whole path would turn
        # "outputs/a.txt" into "<ts>_outputs/a.txt", a directory that does
        # not exist, and open() would fail.
        directory, filename = os.path.split(file_path)
        target_path = os.path.join(directory, f"{timestamp}_{filename}")
        if directory:
            os.makedirs(directory, exist_ok=True)
        with open(target_path, "w", encoding="utf-8") as f:
            f.write(text)
        return f"Text saved to {target_path}"


def agent_invoke(input):
    """Run the research agent for a query and return the agent's result.

    Builds an OpenAI-functions agent with the search, website-loading and
    file-saving tools, wraps the query in the instruction prompt and invokes
    the agent with it.

    Args:
        input: The research query. The name shadows the ``input`` builtin but
            is kept so existing keyword callers keep working.

    Returns:
        Whatever the agent's ``invoke`` call returns.
    """
    agent = initialize_agent(
        llm=llm,
        verbose=True,
        agent=AgentType.OPENAI_FUNCTIONS,
        handle_parsing_errors=True,
        tools=[
            WikipediaSearchTool(),
            DuckDuckGoSearchTool(),
            LoadWebsiteTool(),
            SaveToFileTool(),
        ],
    )

    prompt = PromptTemplate.from_template(
        """    
        Search for the query

        If there is a list of website URLs in the search results, extract the content of each website as text

        Save the content as a .txt file
        
        query: {query}    
        """,
    )

    # Render the prompt to a plain string before calling the agent. Piping the
    # template straight into the agent passes a prompt-value object, which the
    # agent would interpolate as its repr instead of the instruction text.
    instructions = prompt.format(query=input)

    # Return the result so callers can use it; it used to be discarded.
    return agent.invoke({"input": instructions})


# Example run: this query is handed to the agent when the cell executes.
query = "Research about the XZ backdoor"

agent_invoke(input=query)

In [None]:
# with st.sidebar:
#     if "api_key" not in st.session_state:
#         st.session_state["api_key"] = ""

#     api_key_input = st.empty()

#     # def reset_api_key():
#     #     st.session_state["api_key"] = ""
#     #     print(st.session_state["api_key"])

#     # if st.button(":red[Remove Key]"):
#     #     reset_api_key()

#     api_key = api_key_input.text_input(
#         "**:blue[Enter your OpenAI Key:]**",
#         value=st.session_state["api_key"],
#         key="api_key_input",
#     )

#     if api_key != st.session_state["api_key"]:
#         st.session_state["api_key"] = api_key
#         st.rerun()

#     st.text("")
#     st.text("")

#     url = st.text_input(
#         "**:blue[Write down a URL]**",
#         placeholder="https://example.com",
#         value="https://developers.cloudflare.com/sitemap.xml",
#     )
#     url_name = url.split("://")[1].split("/")[0] if url else None

#     st.text("")
#     st.text("")

#     st.markdown(
#         """
#         GitHub Repo: https://github.com/jundev5796/fullstack-gpt/commit/395c6f63c76029c40ea1bc3a8ecfabd2e5f643d5
#         """
#     )