<a href="https://colab.research.google.com/github/mahipalimkar/RAG-from-Scratch/blob/master/RAG_Assisted_Chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers langchain llama-index faiss-cpu chromadb pypdf sentence-transformers unstructured




In [None]:
from pypdf import PdfReader

def extract_text_from_pdf(pdf_path):
    """Read a PDF and return the text of all its pages.

    Args:
        pdf_path: Location of the PDF file.

    Returns:
        One string in which every page's extracted text is followed by "\\n".
    """
    pdf = PdfReader(pdf_path)
    # Each page contributes its text plus a trailing newline, in page order.
    return "".join(page.extract_text() + "\n" for page in pdf.pages)

# Load the source document. The PDF must already exist at this absolute path
# (upload it to the Colab runtime first).
# NOTE(review): this is constitution.pdf, not the Indian Penal Code — the
# example IPC questions asked later are answered from this document's text.
pdf_path = "/constitution.pdf"
legal_text = extract_text_from_pdf(pdf_path)

# Fail fast: an image-only/scanned PDF yields no text, and every retrieval
# step below would then break with a much less obvious error.
assert legal_text.strip(), f"No extractable text found in {pdf_path}"


In [None]:
from sentence_transformers import SentenceTransformer

In [None]:
# Sentence-embedding model shared by document and query encoding
# (get_embeddings() encodes with this global).
embedding_model = SentenceTransformer(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2"
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling%2Fconfig.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [None]:
def get_embeddings(text):
    """Encode ``text`` with the global ``embedding_model``.

    The result is requested as a tensor (``convert_to_tensor=True``); callers
    turn it into a NumPy array with ``.cpu().numpy()`` before using FAISS.
    """
    encoded = embedding_model.encode(text, convert_to_tensor=True)
    return encoded

# Embed the entire extracted document as one vector.
# NOTE(review): sentence-transformers truncates inputs longer than the model's
# max_seq_length, so this vector most likely reflects only the opening part of
# the PDF. The LegalChatbot class below performs its own chunk-level retrieval.
ipc_embedding = get_embeddings(text=legal_text)

In [None]:
import faiss
import numpy as np

In [None]:
# Exact (flat) L2 FAISS index over the document embedding.
# The dimension is read from the embedding itself instead of hard-coding 384,
# so changing the sentence-transformer model cannot silently mismatch the index.
dimension = int(ipc_embedding.shape[-1])
index = faiss.IndexFlatL2(dimension)

# FAISS takes a C-contiguous float32 matrix shaped (n_vectors, dimension).
ipc_embedding_np = np.ascontiguousarray(
    ipc_embedding.cpu().numpy().reshape(1, -1), dtype=np.float32
)
index.add(ipc_embedding_np)

print(f"Added {index.ntotal} document(s) to FAISS index")

Added 1 document(s) to FAISS index


In [None]:
def search(query, k=1):
    """Look up the ``k`` nearest entries of the global FAISS ``index``.

    Args:
        query: Text to embed and search for.
        k: Number of neighbours to request.

    Returns:
        ``(distances, indices)`` exactly as returned by ``index.search`` for a
        single query row.
    """
    vector = get_embeddings(query).cpu().numpy()
    batch_of_one = vector.reshape(1, -1)  # FAISS searches a 2-D batch of queries
    dists, ids = index.search(batch_of_one, k)
    return dists, ids

# Smoke-test the document-level index with a sample question.
# NOTE(review): the index holds a single vector built from constitution.pdf,
# so the nearest "document" is always index 0 whatever the question is.
query = "What is the punishment for theft under IPC?"
distances, indices = search(query=query, k=1)
print("Closest document index:", indices)


Closest document index: [[0]]


In [None]:
# Interactive Hugging Face Hub login (renders a token-entry widget in the notebook).
# Optional here: per the HF_TOKEN notice printed earlier, authentication is
# recommended but not required for public models or datasets.
from huggingface_hub import notebook_login
notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# ~1.1B-parameter chat model from the Hugging Face Hub (safetensors weights ~2.2 GB).
model_name = "TinyLlama/TinyLlama-1.1B-chat-v1.0"

# Tuple elements evaluate left to right: tokenizer first, then the model weights.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)


tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [None]:
# Forget objects from a previous run of the chatbot/UI cells so they are rebuilt
# from scratch. At notebook top level the local namespace is the global one, so
# popping from globals() matches the conditional `del`s; the trailing ';'
# keeps IPython from displaying the popped object.
globals().pop("chatbot", None);
globals().pop("ui", None);

In [None]:
from IPython.display import HTML, display
import ipywidgets as widgets
from typing import List, Tuple
import torch
import re

In [None]:
class LegalChatbot:
    """Retrieval-augmented question answering over a single legal document.

    The document is split into word chunks. The chunks are embedded with the
    module-level ``get_embeddings`` helper and searched with a FAISS L2 index.
    The closest chunks are placed in a prompt for the causal language model.
    Relies on ``get_embeddings``, ``faiss``, ``np`` and ``torch`` being defined
    by earlier cells.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        index: FAISS index from earlier cells. Stored for backward
            compatibility only; retrieval uses the chunk index built here.
        legal_text: Full text of the document to answer from.
        chunk_size: Number of whitespace-separated words per retrieval chunk.
    """

    def __init__(self, model, tokenizer, index, legal_text, chunk_size: int = 150):
        self.model = model
        self.tokenizer = tokenizer
        self.index = index
        self.legal_text = legal_text
        self.chat_history = []
        # Token budget for the retrieved context inside the prompt; the
        # instructions, the question and up to 256 new tokens must fit alongside.
        self.max_context_length = 512
        # ~150 words should keep most chunks inside the sentence embedder's input
        # window (256 word pieces for all-MiniLM-L6-v2, per its model card) and
        # lets the default two chunks fit in max_context_length. With the former
        # 500-word chunks most of each chunk was never embedded, and the second
        # retrieved chunk was (almost) entirely truncated away.
        self.chunk_size = chunk_size
        # (text, chunk_size, chunks, faiss_index), built on first use so the
        # document is embedded once instead of on every question.
        self._chunk_cache = None

    def get_relevant_context(self, query: str, num_chunks: int = 2) -> str:
        """Return the document chunks most similar to ``query``.

        Args:
            query: The user's question.
            num_chunks: How many nearest chunks to concatenate; clamped to the
                number of chunks available.

        Returns:
            The selected chunks joined by spaces and truncated to
            ``self.max_context_length`` tokens, or "" when the document has no
            text or ``num_chunks`` is not positive.
        """
        chunks, chunk_index = self._get_chunk_index()
        k = min(num_chunks, len(chunks))
        if k <= 0:
            return ""

        query_embedding = np.ascontiguousarray(
            get_embeddings(query).cpu().numpy().reshape(1, -1), dtype=np.float32
        )
        _, indices = chunk_index.search(query_embedding, k)

        # FAISS pads missing neighbours with -1; never let that select chunks[-1].
        relevant_text = " ".join(chunks[i] for i in indices[0] if i >= 0)
        return self._truncate_context(relevant_text)

    def _get_chunk_index(self):
        """Return ``(chunks, faiss_index)`` for the current text, building them once."""
        cache = self._chunk_cache
        if cache is not None and cache[0] is self.legal_text and cache[1] == self.chunk_size:
            return cache[2], cache[3]

        chunks = self._split_into_chunks(self.legal_text, chunk_size=self.chunk_size)
        chunk_index = None
        if chunks:
            # One batched encode call instead of one call per chunk; the vector
            # width comes from the embeddings rather than a global constant.
            embeddings = np.ascontiguousarray(
                get_embeddings(chunks).cpu().numpy(), dtype=np.float32
            )
            chunk_index = faiss.IndexFlatL2(embeddings.shape[1])
            chunk_index.add(embeddings)

        self._chunk_cache = (self.legal_text, self.chunk_size, chunks, chunk_index)
        return chunks, chunk_index

    def _truncate_context(self, context: str) -> str:
        """Trim ``context`` to at most ``self.max_context_length`` tokens."""
        # add_special_tokens=False: otherwise a BOS token is counted and, after
        # decode(), leaks into the prompt as literal text (e.g. "<s>" for
        # Llama-style tokenizers).
        tokens = self.tokenizer.encode(context, add_special_tokens=False)
        if len(tokens) > self.max_context_length:
            context = self.tokenizer.decode(
                tokens[:self.max_context_length], skip_special_tokens=True
            )
        return context

    def _split_into_chunks(self, text: str, chunk_size: int) -> List[str]:
        """Split ``text`` on whitespace into consecutive ``chunk_size``-word chunks.

        The last chunk may be shorter; empty or whitespace-only text yields [].
        """
        words = text.split()
        return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]

    @staticmethod
    def _build_prompt(context: str, query: str) -> str:
        """Return the instruction prompt wrapping the retrieved context and question."""
        return (
            "You are a legal assistant. Use the following context to answer the question.\n"
            "If you cannot find relevant information in the context, say so.\n\n"
            f"Context: {context}\n\n"
            f"Question: {query}\n\n"
            "Answer:"
        )

    def generate_response(self, query: str) -> str:
        """Answer ``query`` using retrieved context.

        Returns the generated answer, which is also appended to
        ``self.chat_history``. On any ``Exception`` it returns an apology
        message containing the error text instead of raising, so the chat UI
        keeps working.
        """
        try:
            context = self.get_relevant_context(query)
            prompt = self._build_prompt(context, query)

            # The context is already capped at max_context_length tokens, so the
            # prompt is NOT truncated here: right-truncation would cut off the
            # question and the trailing "Answer:" cue.
            # NOTE(review): an extremely long question could still exceed the
            # model's context window.
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            input_ids = inputs["input_ids"]
            prompt_length = input_ids.shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=inputs.get("attention_mask"),
                    max_new_tokens=256,
                    num_return_sequences=1,
                    do_sample=True,  # temperature/top_p only take effect when sampling
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens. Splitting the full decoded
            # text on "Answer:" breaks when the retrieved context contains it.
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            ).strip()

            self.chat_history.append({"query": query, "response": response})
            return response

        except Exception as e:
            # Deliberate best-effort: the UI shows this message instead of a traceback.
            return f"I apologize, but I encountered an error while processing your question. Please try rephrasing or asking a shorter question. Error: {str(e)}"


In [None]:
class ChatbotUI:
    """Minimal ipywidgets chat front-end for a ``LegalChatbot``.

    Builds a scrollable HTML transcript above a text box and a Send button, and
    displays it as soon as the object is constructed. Pressing Enter in the
    text box or clicking Send forwards the question to the chatbot and appends
    the exchange to the transcript. Uses ``widgets`` and ``display`` imported
    by an earlier cell.
    """

    def __init__(self, chatbot: "LegalChatbot"):
        # String annotation: defining this class no longer requires the
        # LegalChatbot name to exist at definition time.
        self.chatbot = chatbot
        self.setup_ui()

    def setup_ui(self):
        """Create the widgets, register the event handlers and display the UI."""
        # NOTE(review): this Output widget is never added to the layout or used.
        self.output = widgets.Output()
        self.text_input = widgets.Text(
            placeholder='Type your legal question here...',
            layout=widgets.Layout(width='80%')
        )
        self.send_button = widgets.Button(
            description='Send',
            button_style='primary',
            layout=widgets.Layout(width='19%')
        )

        # Transcript area; its HTML value grows with every exchange.
        self.chat_area = widgets.HTML(
            layout=widgets.Layout(width='100%', height='400px', border='1px solid black', overflow='auto')
        )

        input_box = widgets.HBox([self.text_input, self.send_button])
        self.main_layout = widgets.VBox([self.chat_area, input_box])

        # Same handler for a button click and for Enter in the text box.
        # NOTE(review): Text.on_submit is deprecated in ipywidgets 8 — confirm
        # against the installed version.
        self.send_button.on_click(self.on_send_button_clicked)
        self.text_input.on_submit(self.on_send_button_clicked)

        display(self.main_layout)

    def on_send_button_clicked(self, _):
        """Ask the chatbot the current question and append the exchange.

        The widget argument passed by ipywidgets is ignored. Blank input is a
        no-op. The text box is cleared before the (blocking) model call.
        """
        from html import escape

        query = self.text_input.value
        if not query.strip():
            return

        self.text_input.value = ''
        response = self.chatbot.generate_response(query)

        # The transcript is raw HTML: escape user and model text so characters
        # such as '<' or '&' are shown literally instead of being parsed as
        # markup, and keep line breaks visible.
        query_html = escape(query).replace("\n", "<br>")
        response_html = escape(response).replace("\n", "<br>")

        new_message = f"""
            <div style='margin: 10px; padding: 10px;'>
                <div style='background-color: #e6f3ff; padding: 10px; border-radius: 10px; margin-bottom: 5px;'>
                    <strong>You:</strong> {query_html}
                </div>
                <div style='background-color: #f0f0f0; padding: 10px; border-radius: 10px;'>
                    <strong>Assistant:</strong> {response_html}
                </div>
            </div>
            """
        self.chat_area.value = self.chat_area.value + new_message

In [None]:
# Wire the retrieval-augmented chatbot to the widget front-end; constructing
# ChatbotUI displays the chat interface immediately.
chatbot = LegalChatbot(model=model, tokenizer=tokenizer, index=index, legal_text=legal_text)
ui = ChatbotUI(chatbot=chatbot)

VBox(children=(HTML(value='', layout=Layout(border='1px solid black', height='400px', overflow='auto', width='…