# In [None]:
import os
import random
from dotenv import load_dotenv
import streamlit as st
from langchain_groq import ChatGroq
from langchain_community.embeddings import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFDirectoryLoader
import time

# Load environment variables (GROQ_API_KEY, e.g. from a .env file)
load_dotenv()
groq_api_key = os.getenv('GROQ_API_KEY')
if not groq_api_key:
    # Show a readable message in the UI instead of a client-library traceback.
    st.error("GROQ_API_KEY is not set. Add it to your environment or .env file.")
    st.stop()

# Initialize LLM
llm = ChatGroq(model="llama3-8b-8192", groq_api_key=groq_api_key)

# Prompt shared by the "random task" button and the free-text query box.
# The stuff-documents chain fills {context} with the retrieved/selected chunks;
# {input} carries the caller's request so the model actually sees the user's
# query instead of only the retrieved context.
prompt = ChatPromptTemplate.from_template(
    """Based on the content provided, generate a random IELTS Writing Task 2 prompt.
    Do not include any answers or explanations, just the task prompt.
    
    <context>
    {context}
    </context>

    Request: {input}

    Task: Generate an IELTS Writing Task 2 prompt."""
)

# Function to create vector embeddings and load documents
def create_vector_embedding():
    """Load the PDFs in ``task2_writing``, chunk them and build a FAISS index.

    Every intermediate object (embeddings, loader, raw docs, splitter, chunks,
    vector store and the extracted Task 2 chunks) is stored on
    ``st.session_state`` so it survives Streamlit reruns in the same session.
    """
    state = st.session_state
    state.embeddings = OllamaEmbeddings()

    # Data ingestion: read every PDF in the folder.
    state.loader = PyPDFDirectoryLoader("task2_writing")
    state.docs = state.loader.load()

    # Split into overlapping chunks before embedding them.
    state.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    state.final_documents = state.text_splitter.split_documents(state.docs)
    state.vectors = FAISS.from_documents(state.final_documents, state.embeddings)

    # Keep the text of the chunks that look like Task 2 prompts.
    state.writing_tasks = extract_writing_tasks(state.final_documents)

# Function to extract IELTS writing tasks from documents
def extract_writing_tasks(documents):
    """Return the text of every document that mentions "Task 2".

    The literal, case-sensitive substring "Task 2" is used as a heuristic for
    "this chunk contains a writing prompt".

    Args:
        documents: iterable of objects exposing a ``page_content`` string.

    Returns:
        list[str]: the matching ``page_content`` values, in input order.
    """
    return [text for text in (doc.page_content for doc in documents) if "Task 2" in text]

# Function to generate a random question using the LLM
def generate_llm_question():
    """Ask the LLM for a new IELTS Writing Task 2 prompt.

    Retrieving with a fixed query string would return the same chunks on every
    call, so only the model's sampling would vary. Instead, a random sample of
    the ingested chunks is used as context, so each call is grounded in
    different material.

    Returns:
        The generated task text, or a status message when the vector store
        has not been built yet or no chunks are available.
    """
    if "vectors" not in st.session_state:
        return "Vector embeddings are not yet ready."

    documents = st.session_state.get("final_documents") or []
    # Prefer chunks that mention "Task 2" (same heuristic as
    # extract_writing_tasks); fall back to every chunk if none match.
    candidates = [doc for doc in documents if "Task 2" in doc.page_content] or documents
    if not candidates:
        return "No documents were loaded from the task2_writing folder."

    # 4 matches the default number of chunks a LangChain vector-store
    # retriever returns, so the prompt size stays comparable.
    sample = random.sample(candidates, k=min(4, len(candidates)))
    document_chain = create_stuff_documents_chain(llm, prompt)
    # "input" is supplied as well so this works whether or not the prompt
    # template references {input}.
    return document_chain.invoke(
        {"context": sample, "input": "Generate a random IELTS Writing Task"}
    )

# Streamlit UI
st.title("IELTS Writing Task Generator")

# Initialize embeddings and vector database once per session. Loading the PDFs
# and embedding every chunk can take a while, so show progress instead of a
# blank page.
if "vectors" not in st.session_state:
    with st.spinner("Loading documents and building the vector index..."):
        create_vector_embedding()

# Display a single random Task 2 question and update only on button click
if 'question_generated' not in st.session_state:
    st.session_state.question_generated = False

def update_question():
    """Generate a new question and flag that one is available for display."""
    new_task = generate_llm_question()
    st.session_state.current_task = new_task
    st.session_state.question_generated = True

if st.button("Generate Random Writing Task"):
    update_question()

# Show the most recently generated question; it is kept in session_state and
# therefore persists across reruns until the button is clicked again.
if st.session_state.question_generated:
    st.write("Random IELTS Writing Task 2 Question:")
    st.write(st.session_state.current_task)

# Allow users to input custom queries
user_prompt = st.text_input("Enter your query:")

if user_prompt:
    document_chain = create_stuff_documents_chain(llm, prompt)
    retriever = st.session_state.vectors.as_retriever()
    retrieval_chain = create_retrieval_chain(retriever, document_chain)

    # Wall-clock timing: process_time() would count only this process's CPU
    # time and miss the time spent waiting on the embedding and LLM calls.
    start = time.perf_counter()
    response = retrieval_chain.invoke({'input': user_prompt})
    elapsed = time.perf_counter() - start
    st.write(f"Response time: {elapsed:.2f} s")
    st.write(response['answer'])

    with st.expander("Document similarity search"):
        # The retrieved chunks that were stuffed into the prompt as context.
        for doc in response['context']:
            st.write(doc.page_content)
            st.write('--------------')