In [None]:
from collections import deque

import wikipedia
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_openai import OpenAI, ChatOpenAI

In [None]:
# Read the OpenAI key from a local file (keep api_key.txt out of version control).
# strip() removes the trailing newline most editors append, which would otherwise
# become part of the key sent to the API; `with` closes the file handle.
with open('api_key.txt') as key_file:
    api_key = key_file.read().strip()
llm = OpenAI(api_key = api_key)

# Prompt for the chatbot. Placeholders filled on every chat turn:
#   {disorder}          - the condition the user is flagged as at risk for
#   {disorder_context}  - passages retrieved from the FAISS index
#   {survey}            - the user's survey answers (currently passed as "")
#   {past_qs}           - the user's recent previous questions, one per line
#   {question}          - the current question
# The literal opens with exactly three quotes; a fourth would leak a stray '"'
# into the start of every prompt.
template = """You are a helpful and compassionate chatbot explaining about mental health disorders to employees in the tech industry at potential risk.
They answered questions on a survey, and have been found to be at risk for {disorder}. 

Context: {disorder_context}

(if the context is irrelevant, ignore it and answer to the best of your abilities)

Information about the person: 
{survey}

Their Past Prompts: {past_qs}

Their Question: {question}

Answer: """

prompt = PromptTemplate.from_template(template)

In [None]:
# Collect background text for each disorder from Wikipedia: every page returned
# by a search for the disorder name is concatenated, each followed by a blank line.
# This text is chunked and embedded into a FAISS index in the next cell.
disorder_context = {}

for disorder in ['Anxiety Disorder', 'Mood Disorder']:
    page_texts = []
    for title in wikipedia.search(disorder):
        try:
            # Titles come straight from search(), so turn auto-suggest off;
            # with the default (auto_suggest=True) the library may "correct"
            # an exact title and fetch a different article.
            page_texts.append(wikipedia.page(title, auto_suggest=False).content)
        except Exception as err:
            # Best effort: skip ambiguous/missing pages (or transient network
            # errors), but report it. A bare `except:` would also swallow
            # KeyboardInterrupt and hide why the context came out short.
            print(f"Skipping {title!r} for {disorder!r}: {type(err).__name__}")

    disorder_context[disorder] = ''.join(content + '\n\n' for content in page_texts)

In [None]:
# Build the retrieval index for the anxiety text: split on '.' into chunks of
# roughly 350 characters (no overlap), embed them, and persist the FAISS index.
# The chat loop loads indexes from faiss_databases/<disorder>, so save under
# that same directory; otherwise a fresh Run-All fails when the loop loads it.
text_splitter = CharacterTextSplitter(chunk_size=350, chunk_overlap=0, separator='.')
chunked_texts = text_splitter.split_text(disorder_context['Anxiety Disorder'])

db = FAISS.from_texts(chunked_texts, OpenAIEmbeddings(api_key=api_key))
db.save_local('faiss_databases/anxiety')

In [None]:
# Sanity check the index: show the two chunks nearest to a sample query.
sample_matches = db.similarity_search("I have paranoia", k = 2)
sample_matches

In [None]:
# Interactive chat loop. For each question:
#   1. retrieve related passages from the disorder's FAISS index,
#   2. include the user's last few questions as lightweight memory,
#   3. run the prompt through the LLM and print the answer.
# Type "quit" on its own line to exit.
MAX_PAST_QUESTIONS = 5

llm_chain = LLMChain(prompt=prompt, llm=llm)

# The disorder and its index do not change between turns, so load the index
# (and create the embeddings client) once, not once per question.
disorder = "anxiety"
# NOTE(review): load_local unpickles the stored docstore, so only load indexes
# you built yourself. Depending on the installed langchain_community version it
# may also require allow_dangerous_deserialization=True -- TODO confirm.
db = FAISS.load_local(f'faiss_databases/{disorder}', OpenAIEmbeddings(api_key = api_key))

# A bounded deque keeps only the most recent questions (oldest dropped first).
past_questions = deque(maxlen=MAX_PAST_QUESTIONS)

while True:
    question = input(">> ")
    command = question.strip()
    # Exact match, not a substring test: "Should I quit my job?" is a real
    # question and must not end the session.
    if command.lower() == 'quit':
        break
    if not command:
        # Don't spend an embedding + LLM call on an empty line.
        continue

    context = "".join(doc.page_content + "\n" for doc in db.similarity_search(question))
    memory_content = "".join(past + "\n" for past in past_questions)

    ans = llm_chain.invoke({'disorder': disorder, 'disorder_context': context, 'survey': "", 'question': question, 'past_qs': memory_content})['text']
    print(ans)

    past_questions.append(question)