In [None]:
!pip install transformers langchain llama-index faiss-cpu chromadb pypdf sentence-transformers unstructured



In [None]:
import torch
import numpy as np
import faiss
from typing import List
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from pypdf import PdfReader
import re

In [None]:
def extract_text_from_pdf(pdf_path):
    """Return the text of every page of a PDF, each page followed by a newline.

    Args:
        pdf_path: Path of the PDF, passed straight to ``PdfReader``.

    Returns:
        The concatenated page texts. A page that yields no text contributes
        only its trailing newline.
    """
    reader = PdfReader(pdf_path)
    # The ``or ""`` guard keeps a page with no extractable text (empty or None
    # result) from raising TypeError and aborting the whole document.
    # Joining once avoids rebuilding the accumulated string for every page.
    return "".join(f"{page.extract_text() or ''}\n" for page in reader.pages)

In [None]:
def split_into_chunks(text: str, chunk_size: int = 500, overlap: int = 100) -> List[str]:
    """Split text into overlapping, word-count-bounded chunks for retrieval.

    The text is split into sentences at whitespace that follows ``.``, ``!`` or
    ``?``. Whole sentences are packed into the current chunk while its word
    count stays within ``chunk_size``. When the next sentence would overflow,
    the current chunk is emitted and the next one starts with its last
    ``overlap`` words followed by that sentence.

    Args:
        text: Source text to split.
        chunk_size: Target maximum number of words per chunk. A single sentence
            longer than this is kept whole, so such a chunk exceeds the target.
        overlap: Number of trailing words carried into the next chunk.

    Returns:
        Chunks as space-joined strings; an empty list if ``text`` has no words.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current_chunk = []
    current_size = 0

    for sentence in sentences:
        sentence_words = sentence.split()
        # re.split leaves an empty piece when text ends with punctuation plus
        # whitespace (always the case for PDF text ending in ".\n"). Such a piece
        # adds nothing, but could still force a flush of an oversized chunk and
        # emit a duplicate chunk made only of the overlap tail.
        if not sentence_words:
            continue
        sentence_len = len(sentence_words)

        if current_size + sentence_len <= chunk_size:
            current_chunk.extend(sentence_words)
            current_size += sentence_len
        else:
            # Nothing to flush when the very first sentence already exceeds
            # chunk_size; appending " ".join([]) would index an empty chunk.
            if current_chunk:
                chunks.append(" ".join(current_chunk))

            # Seed the next chunk with the tail of the previous one.
            overlap_start = max(0, len(current_chunk) - overlap)
            current_chunk = current_chunk[overlap_start:] + sentence_words
            current_size = len(current_chunk)

    if current_chunk:
        chunks.append(" ".join(current_chunk))

    return chunks

In [None]:
class ImprovedLegalChatbot:
    """Retrieval-augmented question answering over a single legal document.

    The document is split into overlapping word chunks. Each chunk is embedded
    with a SentenceTransformer model and stored in a FAISS L2 index. For each
    question the nearest chunks are retrieved and combined with the last two
    exchanges into a prompt for a causal language model.
    """

    # Slack (in tokens) reserved when fitting the prompt, because tokenizing the
    # prompt pieces separately is not exactly additive with tokenizing the whole.
    prompt_margin_tokens = 8

    def __init__(self, model, tokenizer, legal_text, embedding_model_name="sentence-transformers/all-mpnet-base-v2"):
        """
        Args:
            model: Causal LM exposing ``generate`` (e.g. ``AutoModelForCausalLM``).
            tokenizer: Tokenizer matching ``model``.
            legal_text: Full document text to index.
            embedding_model_name: SentenceTransformer model used for retrieval.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.legal_text = legal_text
        self.chat_history = []  # list of {"query": ..., "response": ...} dicts
        # Token budget for the whole prompt: instructions + history + context + question.
        self.max_context_length = 1024

        self.embedding_model = SentenceTransformer(embedding_model_name)
        self.embedding_dimension = self.embedding_model.get_sentence_embedding_dimension()

        self.chunks = []
        self.initialize_index()

    def initialize_index(self, chunk_size: int = 500, overlap: int = 100):
        """Chunk ``self.legal_text`` and build the FAISS index once.

        Args:
            chunk_size: Maximum words per chunk, passed to ``split_into_chunks``.
            overlap: Words shared between consecutive chunks.
        """
        print("Initializing FAISS index...")

        self.chunks = split_into_chunks(self.legal_text, chunk_size, overlap)

        self.index = faiss.IndexFlatL2(self.embedding_dimension)
        # A document without words yields no chunks (np.vstack([]) would raise);
        # keep the empty index so retrieval just returns no context.
        if self.chunks:
            self.index.add(self._embed(self.chunks))

        print(f"Added {len(self.chunks)} chunks to FAISS index")

    def _embed(self, texts: List[str]) -> np.ndarray:
        """Encode ``texts`` in one batch into a float32 ``(len(texts), dim)`` array."""
        vectors = self.embedding_model.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
        # Unit-length vectors make L2 ranking identical to cosine-similarity ranking.
        # FAISS expects contiguous float32 input.
        return np.asarray(vectors, dtype=np.float32).reshape(len(texts), -1)

    def get_relevant_context(self, query: str, num_chunks: int = 3) -> str:
        """Return the chunks nearest to ``query`` formatted as numbered excerpts.

        Args:
            query: The user's question.
            num_chunks: Maximum number of chunks to retrieve.

        Returns:
            ``"Excerpt 1: ...\\n\\nExcerpt 2: ..."`` truncated to
            ``self.max_context_length`` tokens, or ``""`` when the index is
            empty or ``num_chunks`` is not positive.
        """
        available = min(num_chunks, self.index.ntotal, len(self.chunks))
        if available <= 0:
            return ""

        query_embedding = self._embed([query])
        _, indices = self.index.search(query_embedding, available)

        # FAISS pads missing results with -1, and list[-1] would silently yield
        # the last chunk, so keep only valid ids.
        relevant_chunks = [self.chunks[i] for i in indices[0] if 0 <= i < len(self.chunks)]

        formatted_chunks = [f"Excerpt {n}: {chunk}" for n, chunk in enumerate(relevant_chunks, start=1)]
        return self._truncate_context("\n\n".join(formatted_chunks))

    def _truncate_context(self, context: str, max_tokens=None) -> str:
        """Truncate ``context`` to at most ``max_tokens`` tokens.

        Args:
            context: Text to shorten.
            max_tokens: Token limit; defaults to ``self.max_context_length``.
                A limit of zero or less yields an empty string.
        """
        limit = self.max_context_length if max_tokens is None else max_tokens
        if limit <= 0:
            return ""
        # Count only the text itself. For tokenizers that add a BOS token, that
        # token would otherwise be counted and, after decoding, pasted into the
        # middle of the prompt.
        tokens = self.tokenizer.encode(context, add_special_tokens=False)
        if len(tokens) > limit:
            context = self.tokenizer.decode(tokens[:limit], skip_special_tokens=True)
        return context

    def _format_history(self) -> str:
        """Render the last two exchanges as ``User:`` / ``Assistant:`` lines."""
        return "".join(
            f"User: {entry['query']}\nAssistant: {entry['response']}\n\n"
            for entry in self.chat_history[-2:]
        )

    def _build_prompt(self, history_text: str, context: str, query: str) -> str:
        """Assemble the instruction prompt (no stray source-code indentation)."""
        return (
            "You are a helpful legal assistant. Use the following context from legal documents to answer the question accurately.\n"
            "If the context doesn't contain relevant information, say so instead of making up information.\n"
            "If you're uncertain, express your uncertainty and note that a legal professional should be consulted.\n"
            "\n"
            f"Previous conversation:\n{history_text}\n"
            "\n"
            f"Legal context:\n{context}\n"
            "\n"
            f"Question: {query}\n"
            "\n"
            "Answer:"
        )

    def _fit_prompt(self, history_text: str, context: str, query: str) -> str:
        """Build a prompt whose context is trimmed so the whole prompt fits the budget.

        Truncating the assembled prompt from the right would drop the question
        and the trailing ``Answer:`` cue. So the fixed parts are measured first,
        and the context gets only the tokens that remain. If the fixed parts
        alone are too long, the conversation history is dropped.
        """
        for history in (history_text, ""):
            # Default encode() includes special tokens, matching the final tokenizer call.
            overhead = len(self.tokenizer.encode(self._build_prompt(history, "", query)))
            budget = self.max_context_length - overhead - self.prompt_margin_tokens
            if budget > 0:
                return self._build_prompt(history, self._truncate_context(context, budget), query)
        # Even the bare question exceeds the budget; the tokenizer's truncation
        # in generate_response is the last resort.
        return self._build_prompt("", "", query)

    def generate_response(self, query: str) -> str:
        """Answer ``query`` from retrieved document context and recent history.

        Returns:
            The model's answer (also appended to ``chat_history``), or an
            apology message containing the error text if anything fails.
        """
        try:
            context = self.get_relevant_context(query)
            prompt = self._fit_prompt(self._format_history(), context, query)

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=self.max_context_length)
            device = getattr(self.model, "device", None)
            if device is not None:
                inputs = inputs.to(device)
            prompt_length = inputs["input_ids"].shape[1]

            with torch.no_grad():
                outputs = self.model.generate(
                    inputs["input_ids"],
                    # pad and eos share an id here, so generate() cannot infer
                    # the mask itself; pass it explicitly.
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=512,
                    num_return_sequences=1,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True,
                    repetition_penalty=1.2,
                    pad_token_id=self.tokenizer.eos_token_id,
                )

            # generate() returns prompt + continuation: decode only the new tokens
            # rather than splitting on "Answer:", which breaks if the prompt was
            # truncated or the answer itself contains "Answer:".
            response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

            self.chat_history.append({"query": query, "response": response})

            return response

        except Exception as e:
            # Best-effort for the interactive UI: report the failure instead of raising.
            return f"I apologize, but I encountered an error while processing your question. Please try rephrasing or asking a shorter question. Error: {str(e)}"

In [None]:
class ImprovedChatbotUI:
    """Minimal ipywidgets chat front end around a chatbot's ``generate_response``."""

    def __init__(self, chatbot):
        """Store the chatbot, then build and display the widgets."""
        self.chatbot = chatbot
        self.setup_ui()

    def setup_ui(self):
        """Build the transcript area, text box and Send button, and display them.

        Message bubbles (added in ``on_send_button_clicked``) set an explicit
        black text color to avoid the white-text-on-light-background issue.
        """
        from IPython.display import display
        import ipywidgets as widgets

        # Not part of the displayed layout; kept as a public attribute.
        self.output = widgets.Output()
        self.text_input = widgets.Text(
            placeholder='Type your legal question here...',
            layout=widgets.Layout(width='80%')
        )
        self.send_button = widgets.Button(
            description='Send',
            button_style='primary',
            layout=widgets.Layout(width='19%')
        )

        self.chat_area = widgets.HTML(
            value="<div style='font-family: Arial, sans-serif; padding: 10px;'><h3>Legal Assistant</h3><p>Ask me questions about legal documents.</p></div>",
            layout=widgets.Layout(width='100%', height='400px', border='1px solid #ccc', overflow='auto')
        )

        input_box = widgets.HBox([self.text_input, self.send_button])
        self.main_layout = widgets.VBox([self.chat_area, input_box])

        # Clicking Send and pressing Enter share one handler.
        self.send_button.on_click(self.on_send_button_clicked)
        self.text_input.on_submit(self.on_send_button_clicked)

        display(self.main_layout)

    def on_send_button_clicked(self, _):
        """Send the typed question and append the exchange to the transcript.

        The widget argument passed by ipywidgets is ignored. Blank input is
        ignored without clearing the text box.
        """
        import html

        query = self.text_input.value
        if not query.strip():
            return

        self.text_input.value = ''

        response = self.chatbot.generate_response(query)

        # Both strings are untrusted (user input and model output): escape them
        # so "<", "&" or markup display literally instead of being parsed as HTML.
        # The pre-wrap span keeps line breaks in multi-line answers.
        safe_query = html.escape(query)
        safe_response = html.escape(response)

        current_html = self.chat_area.value
        new_message = f"""
            <div style='margin: 10px; padding: 10px;'>
                <div style='background-color: #e6f3ff; padding: 10px; border-radius: 10px; margin-bottom: 5px; color: black;'>
                    <strong>You:</strong> <span style='white-space: pre-wrap;'>{safe_query}</span>
                </div>
                <div style='background-color: #f0f0f0; padding: 10px; border-radius: 10px; color: black;'>
                    <strong>Assistant:</strong> <span style='white-space: pre-wrap;'>{safe_response}</span>
                </div>
            </div>
            """
        self.chat_area.value = current_html + new_message


def create_improved_chatbot(pdf_path, model_name="TinyLlama/TinyLlama-1.1B-chat-v1.0"):
    """Wire a PDF, a causal LM and the chat UI together.

    Args:
        pdf_path: Path of the PDF whose text is indexed for retrieval.
        model_name: Hugging Face model id used for both tokenizer and model.

    Returns:
        ``(chatbot, ui)``: the ``ImprovedLegalChatbot`` and the
        ``ImprovedChatbotUI`` wrapping it. Constructing the UI displays it.
    """
    document_text = extract_text_from_pdf(pdf_path)

    lm_tokenizer = AutoTokenizer.from_pretrained(model_name)
    language_model = AutoModelForCausalLM.from_pretrained(model_name)

    bot = ImprovedLegalChatbot(language_model, lm_tokenizer, document_text)
    return bot, ImprovedChatbotUI(bot)

In [None]:
# Colab upload location; point this at your own PDF when running elsewhere.
PDF_PATH = "/content/constitution.pdf"
chatbot, ui = create_improved_chatbot(PDF_PATH)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

1_Pooling%2Fconfig.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Initializing FAISS index...
Added 388 chunks to FAISS index


VBox(children=(HTML(value="<div style='font-family: Arial, sans-serif; padding: 10px;'><h3>Legal Assistant</h3…

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
