**Authors:**

- Ravi Teja Kothuru (Primary)
- Soumi Ray
- Anwesha Sarangi

**Title of the Project:** SmartChat: A Context-Aware Conversational Agent

**Description of the Project:** Develop a chatbot that adapts to context and topic shifts in a conversation. It uses the Stanford Question Answering Dataset to give informed, relevant answers and thereby increase user satisfaction and engagement.

**Objectives of the Project:** Create a user-friendly web or app interface that lets users hold natural, coherent conversations with the chatbot and that achieves a high user-satisfaction rating.

**Name of the Dataset:** Stanford Question Answering Dataset

**Description of the Dataset:** The Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset consisting of questions posed by crowdworkers on a set of Wikipedia articles. The answer to every question is a segment of text, or span, from the corresponding reading passage. There are 100,000+ question-answer pairs on 500+ articles. More information can be found at: https://rajpurkar.github.io/SQuAD-explorer/
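Because every answer is a span of its passage, it can be recovered by slicing the context at the answer's character offset. The following is a minimal sketch on a made-up record that mirrors SQuAD's `answer_start`/`text` answer fields. The text and the helper name `extract_span` are illustrative only.

```python
# Toy SQuAD-style record (made-up text, not taken from the dataset)
context = "The Eiffel Tower was completed in 1889 and stands in Paris."
answer = {"text": "1889", "answer_start": 34}


def extract_span(context: str, answer_start: int, text: str) -> str:
    """Return the slice of `context` that an answer annotation points at."""
    return context[answer_start:answer_start + len(text)]


span = extract_span(context, answer["answer_start"], answer["text"])
print(span)  # 1889
```

This is also a cheap sanity check during preprocessing. If the slice does not equal the answer text, the offset and the context are out of sync.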

**Dataset Source:**

Kaggle (https://www.kaggle.com/datasets/stanfordu/stanford-question-answering-dataset)

***Number of Variables in Dataset:*** Each JSON file has 2 top-level variables:

- `data`
- `version`

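`version` is the dataset version string (`1.1` for these files), and `data` holds the nested article records. Each question-answer example is described by the following variables. `ans_end` is not stored in the raw JSON. It can be computed from `ans_start` and the answer length.

- ***context:*** A paragraph from a Wikipedia article that serves as the reading passage.
- ***question:*** A crowdworker-written question about the context.
- ***answer:*** The answer to the question, a span of text copied from the context.
- ***ans_start:*** The character index in the context where the answer starts.
- ***ans_end:*** The character index in the context where the answer ends.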
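The notebook does not show how the nested JSON was turned into these variables. The following is a minimal sketch of one way to do it, under stated assumptions. It walks SQuAD v1.1's `data` → `paragraphs` → `qas` → `answers` nesting and keeps only the first answer per question. It also assumes that `ans_end` is an exclusive end offset (`ans_start + len(answer)`). The function name `flatten_squad` and the toy `sample` are hypothetical.

```python
from typing import Any


def flatten_squad(squad: dict[str, Any]) -> list[dict[str, Any]]:
    """Flatten SQuAD v1.1 JSON into one row per question."""
    rows = []
    for article in squad["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                # Keep the first annotated answer (dev-set questions may list several)
                first = qa["answers"][0]
                start = first["answer_start"]
                rows.append({
                    "context": context,
                    "question": qa["question"],
                    "answer": first["text"],
                    "ans_start": start,
                    # Assumption: ans_end is exclusive, i.e. start + answer length
                    "ans_end": start + len(first["text"]),
                })
    return rows


# Tiny hand-made input with the same nesting as train-v1.1.json
sample = {
    "version": "1.1",
    "data": [{
        "title": "Toy_article",
        "paragraphs": [{
            "context": "Gradio builds web demos for machine learning models.",
            "qas": [{"id": "q1", "question": "What does Gradio build?",
                     "answers": [{"text": "web demos", "answer_start": 14}]}],
        }],
    }],
}

rows = flatten_squad(sample)
print(rows[0]["ans_start"], rows[0]["ans_end"])  # 14 23
```

With an exclusive end offset, `context[ans_start:ans_end]` returns the answer text directly.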

***Size of the Dataset:*** The dataset consists of 2 JSON files: one for training and one for validation.

- The training file is train-v1.1.json, and its size is 30.3 MB.
- The validation file is dev-v1.1.json, and its size is 4.9 MB.
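To confirm the structure and scale of each split after downloading, a small summary function is enough. The following is a minimal sketch using only the standard library. It assumes the files keep their Kaggle names and sit in the working directory, and the helper name `summarize_split` is hypothetical.

```python
import json
from pathlib import Path


def summarize_split(path: str) -> dict:
    """Load one SQuAD JSON file and count its articles and question-answer pairs."""
    with Path(path).open(encoding="utf-8") as f:
        squad = json.load(f)
    n_articles = len(squad["data"])
    n_questions = sum(
        len(paragraph["qas"])
        for article in squad["data"]
        for paragraph in article["paragraphs"]
    )
    return {"version": squad["version"], "articles": n_articles, "questions": n_questions}


# Usage, assuming the files were downloaded to the working directory:
# print(summarize_split("train-v1.1.json"))
# print(summarize_split("dev-v1.1.json"))
```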

# Install and import necessary libraries
!pip install gradio transformers peft
import gradio as gr
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM
import gc


# Function to load the tokenizer and model
def load_model_and_tokenizer(model_path):
    """
    Loads the tokenizer and the fine-tuned LoRA model.

    Args:
        model_path (str): Path to the fine-tuned LoRA model.

    Returns:
        tokenizer, model, device
    """
    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Loading model on device: {device}")

    # Load the base GPT-2 Medium tokenizer; GPT-2 has no pad token, so reuse EOS
    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2-medium")

    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Load fine-tuned LoRA model
    model = AutoPeftModelForCausalLM.from_pretrained(model_path)
    model.to(device)
    model.eval()

    model.generation_config.pad_token_id = tokenizer.pad_token_id

    # Free memory left over from loading (the CUDA cache only matters on GPU)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return tokenizer, model, device


# Load the model and tokenizer
tokenizer, model, device = load_model_and_tokenizer("gpt2-medium-lora")


# Function to generate answers
def generate_answer(context, question):
    """
    Generates an answer based on the provided context and question.

    Args:
        context (str): The context for the questions.
        question (str): The user's question.

    Returns:
        str: Generated answer.
    """
    try:
        # Build the prompt in the same format used during fine-tuning
        prompt = f"Context: {context}\nQuestion: {question}\n <|start_answer|>"

        # Tokenize without padding: a single prompt needs none, and right-padding with EOS
        # would make the model continue after the pad tokens. Move tensors to the model's device.
        inputs = tokenizer(prompt, truncation=True, max_length=800, return_tensors="pt").to(device)

        # Generate output
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=32,
                do_sample=True,
                top_k=50,
                top_p=0.95,
                temperature=0.2,
                num_return_sequences=1,
            )

        # Decode only the newly generated tokens, i.e. the text after <|start_answer|>
        prompt_length = inputs["input_ids"].shape[1]
        generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        # The model may go on to write a new "Question:" block; keep only the answer
        question_index = generated_text.find("Question:")
        if question_index != -1:
            generated_text = generated_text[:question_index].strip()
        return generated_text

    except Exception as e:
        print(f"Error during generation: {e}")
        return "❌ An error occurred while generating the answer."


# Gradio Interface
def chatbot_interface():
    """
    Create the Gradio interface for the chatbot.

    Returns:
        gr.Blocks: A Gradio interface with context-aware chatbot functionality.
    """
    with gr.Blocks() as demo:
        # Adding custom CSS for beautifying the interface
        gr.Markdown("""
            <style>
                body {
                    background-color: #f0f0f0;  /* Light gray background */
                }
                .chatbot-container {
                    background-color: #ffffff;  /* White background for chatbot area */
                    border-radius: 10px;
                    padding: 20px;
                    color: #333;  /* Dark text color */
                    font-family: Arial, sans-serif;
                }
                .gr-button {
                    background-color: #4CAF50;  /* Green button */
                    color: white;
                    border: none;
                    border-radius: 5px;
                    padding: 10px 20px;
                    font-size: 14px;
                    cursor: pointer;
                }
                .gr-button:hover {
                    background-color: #45a049;  /* Darker green on hover */
                }
                .gr-textbox {
                    background-color: #ffffff;  /* White background for textboxes */
                    color: #333;  /* Dark text color in textbox */
                    border-radius: 5px;
                    border: 1px solid #ddd;
                    padding: 10px;
                }
                .gr-chatbot {
                    background-color: #e6e6e6;  /* Light gray background for chatbot */
                    border-radius: 10px;
                    padding: 15px;
                    color: #333;
                }
                .status-message {
                    color: #007bff;  /* Blue status message */
                    font-weight: bold;
                }
                .footer {
                    text-align: right;
                    font-size: 12px;
                    color: #777;
                    font-style: italic;
                }
            </style>
        """)

        # Header
        gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>🧠 SmartChat: A Context-Aware Conversational Agent</h1>")
        gr.Markdown("<p style='text-align: center; color: #777;'>Set a context and then ask multiple questions based on that context.</p>")

        # State that stores the current context across questions
        context_state = gr.State()

        with gr.Row():
            with gr.Column(scale=1):
                # Context input
                context_input = gr.Textbox(
                    label="Set Context",
                    placeholder="Enter the context here...",
                    lines=4
                )
                set_context_btn = gr.Button("Set Context")

                # Clear Context button
                clear_context_btn = gr.Button("Clear Context")

                # Status message
                status_message = gr.Markdown("")

            with gr.Column(scale=2):
                # Chatbot display
                chatbot = gr.Chatbot(label="Chatbot")

        # Question input
        question_input = gr.Textbox(
            label="Ask a Question",
            placeholder="Enter your question here...",
            lines=1
        )
        submit_btn = gr.Button("Submit Question")

        footer = gr.Markdown("""
            <div style='display: flex; justify-content: space-between; font-size: 12px; color: #777;'>
                <p style='margin: 0;'>Trained using: GPT-2 Medium with LoRA</p>
                <p style='margin: 0;'>Prepared by: Ravi Teja Kothuru, Soumi Ray and Anwesha Sarangi</p>
            </div>
        """)

        # Function to set context
        def set_context(context):
            """
            Set the provided context for future question-answering.

            Args:
                context (str): The context to set.

            Returns:
                tuple: A tuple of updated UI components after setting the context.
            """
            if not context.strip():
                return gr.update(), "Please enter a valid context.", None
            return gr.update(visible=False), "Context has been set. You can now ask questions.", context

        # Function to clear context
        def clear_context():
            """
            Clear the current context.

            Returns:
                tuple: A tuple of updated UI components after clearing the context.
            """
            return gr.update(visible=True), "Context has been cleared. Please set a new context.", None

        # Function to handle question submission
        def handle_question(question, history, context):
            """
            Handle the question by generating an answer based on the context.

            Args:
                question (str): The question to answer.
                history (list): The conversation history.
                context (str): The context for generating the answer.

            Returns:
                tuple: Updated conversation history and the new value for the question box
                    (empty on success, otherwise a warning message).
            """
            # Guard against an empty (None) chat value before the first message
            history = history or []
            if not context:
                return history, "Please set the context before asking questions."
            if not question.strip():
                return history, "Please enter a valid question."

            answer = generate_answer(context, question)
            history = history + [[f"👤 : {question}", f"🤖 : {answer}"]]
            return history, ""

        # Event bindings
        set_context_btn.click(set_context, inputs=context_input, outputs=[context_input, status_message, context_state])
        clear_context_btn.click(clear_context, inputs=None, outputs=[context_input, status_message, context_state])
        submit_btn.click(
            handle_question,
            inputs=[question_input, chatbot, context_state],
            outputs=[chatbot, question_input]
        )

    return demo


# Launch the Gradio app
if __name__ == "__main__":
    demo = chatbot_interface()
    demo.launch()


Loading model on device: cpu




* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
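The prompt template and the answer clean-up step in `generate_answer` do not depend on the model, so they can be checked without loading GPT-2. The following is a minimal sketch that factors them into two helpers. The names `build_prompt` and `clean_answer` are hypothetical and not part of the notebook. The template string is copied from `generate_answer` above.

```python
PROMPT_TEMPLATE = "Context: {context}\nQuestion: {question}\n <|start_answer|>"


def build_prompt(context: str, question: str) -> str:
    """Format a context/question pair like generate_answer does."""
    return PROMPT_TEMPLATE.format(context=context, question=question)


def clean_answer(generated: str) -> str:
    """Trim decoded text down to the answer.

    Drops everything up to <|start_answer|> if the marker is present, then
    cuts off any follow-on "Question:" block the model starts to write.
    """
    marker = "<|start_answer|>"
    idx = generated.find(marker)
    if idx != -1:  # test the raw find() result, before adding the marker length
        generated = generated[idx + len(marker):]
    q_idx = generated.find("Question:")
    if q_idx != -1:
        generated = generated[:q_idx]
    return generated.strip()


prompt = build_prompt("Paris is the capital of France.", "What is the capital of France?")
print(clean_answer(prompt + " Paris\nQuestion: What is"))  # Paris
```

Checking `find()` before adding the marker length matters. Adding first would turn a "not found" result of `-1` into a positive index and silently chop the start of the text.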
