<a href="https://colab.research.google.com/github/mehdihoore/AstraChatbot/blob/main/AstraWithGemini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import json
import uuid
from pathlib import Path
from typing import Any, List, Tuple, Optional
from dataclasses import dataclass
from langchain.vectorstores import VectorStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import Docx2txtLoader  # Updated import path for DocxLoader
from langchain.document_loaders import PyPDFLoader  # Updated import path for PDFLoader
from langchain.document_loaders import TextLoader
from astrapy import DataAPIClient
import re
from google.colab import userdata
from astrapy.constants import VectorMetric
from astrapy.database import Database
from astrapy.collection import Collection
from astrapy.exceptions import CollectionAlreadyExistsException
import time
import google.generativeai as genai

# Constants and configurations
# File extensions (lower-case, without the leading dot) accepted by DocumentProcessor.load_file.
TEXT_FILE_TYPES = ["txt", "docx", "pdf"]
# Credentials are fetched through google.colab's `userdata.get`, keyed by secret name.
ASTRA_DB_APPLICATION_TOKEN = userdata.get("ASTRA_DB_APPLICATION_TOKEN_GBOOKS")
ASTRA_DB_API_ENDPOINT = userdata.get("ASTRA_DB_API_ENDPOINT_GBOOKS")
GEMINI_API_KEY = userdata.get("GOOGLE_API_KEY").strip()  # API key for Gemini; surrounding whitespace is stripped
# NOTE(review): this constant is not referenced in the visible code; the component
# classes carry the same model name as their own default argument.
GEMINI_EMBEDDING_MODEL_NAME = "models/text-embedding-004"  # Model for Gemini embeddings

@dataclass
class Data:
    """Container for the text extracted from a loaded document."""

    # Text content of the document; empty by default.
    content: str = ""
    # Reference string associated with the document; empty by default.
    references: str = ""

class DocumentProcessor:
    """Load the text of a single ``.txt``, ``.docx`` or ``.pdf`` file.

    Args:
        path: Location of the file to load.
        silent_errors: Stored on the instance; the methods of this class do
            not consult it.
    """

    def __init__(self, path: str, silent_errors: bool = False):
        self.path = path
        self.silent_errors = silent_errors
        self.status = ""

    def resolve_path(self, path: str) -> str:
        """Return the absolute form of *path*."""
        return os.path.abspath(path)

    def load_file(self) -> Tuple["Data", str]:
        """Read the configured file and return its text and document name.

        The loader is chosen from the file extension: ``Docx2txtLoader`` for
        ``docx``, ``PyPDFLoader`` for ``pdf`` and ``TextLoader`` otherwise.

        Returns:
            A ``(Data, name)`` pair where ``name`` is the file name without
            its extension. When the loader yields nothing, ``(Data(), "")``.

        Raises:
            ValueError: If no path was given or the extension is not listed
                in ``TEXT_FILE_TYPES``.
        """
        if not self.path:
            raise ValueError("Please, upload a file to use this component.")
        resolved_path = self.resolve_path(self.path)

        extension = Path(resolved_path).suffix[1:].lower()
        if extension not in TEXT_FILE_TYPES:
            raise ValueError(f"Unsupported file type: {extension}")

        if extension == "docx":
            loader = Docx2txtLoader(resolved_path)
        elif extension == "pdf":
            loader = PyPDFLoader(resolved_path)
        else:  # Treat as text file
            loader = TextLoader(resolved_path)

        documents = loader.load()
        if not documents:
            return Data(), ""

        # A loader may return several documents (PyPDFLoader produces one per
        # page). Keep all of them; taking only the first would silently drop
        # the rest of the file. A single-document result is unchanged.
        content = "\n".join(doc.page_content for doc in documents)
        return Data(content=content), Path(resolved_path).stem

class RecursiveCharacterTextSplitterComponent:
    """Configurable front-end for ``RecursiveCharacterTextSplitter``.

    Args:
        chunk_size: Value forwarded as the splitter's ``chunk_size``.
        chunk_overlap: Value forwarded as the splitter's ``chunk_overlap``.
        separators: Separator strings to split on. ``None`` or an empty list
            selects the default ``[".", "\\n"]``.
    """

    def __init__(self, chunk_size: int = 1000, chunk_overlap: int = 200, separators: Optional[List[str]] = None):
        # Any falsy value (None, []) falls back to the default separators.
        self.separators = separators or [".", "\n"]
        self.chunk_overlap = chunk_overlap
        self.chunk_size = chunk_size

    def split_text(self, text: str) -> List[str]:
        """Split *text* with a freshly built splitter and return the chunks."""
        options = {
            "separators": self.separators,
            "chunk_size": self.chunk_size,
            "chunk_overlap": self.chunk_overlap,
        }
        return RecursiveCharacterTextSplitter(**options).split_text(text)

class GeminiEmbeddingsComponent:
    """Generate document embeddings through ``genai.embed_content``.

    Args:
        gemini_api_key: Key passed to ``genai.configure``.
        model_name: Embedding model identifier sent with every request.
    """

    def __init__(self, gemini_api_key: str, model_name: str = "models/text-embedding-004"):
        self.model = model_name
        genai.configure(api_key=gemini_api_key)

    def build_embeddings(
        self,
        texts: List[str],
        expected_dim: int = 768,
        delay_seconds: float = 60.0,
        max_attempts: int = 3,
    ) -> List[List[float]]:
        """Embed *texts* in batches of ten, keeping results aligned with the input.

        The returned list always has ``len(texts)`` entries, and entry ``i``
        belongs to ``texts[i]``. When a text cannot be embedded (its batch
        failed on every attempt, or the vector does not have ``expected_dim``
        components) its slot holds an empty list. Callers must skip those
        slots. Dropping entries instead would shift every later vector onto
        the wrong text.

        Args:
            texts: Strings to embed.
            expected_dim: Required length of each embedding vector.
            delay_seconds: Pause after every request attempt, used to stay
                within the service's rate limits.
            max_attempts: Number of tries per batch before giving up on it.

        Returns:
            One vector (or empty placeholder) per input text.
        """
        embeddings: List[List[float]] = []
        batch_size = 10  # Adjust batch size as needed

        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            vectors = self._embed_batch(batch, delay_seconds, max_attempts)

            if vectors is None:
                print(f"Warning: Giving up on texts {start}-{start + len(batch) - 1}; inserting empty placeholders.")
                embeddings.extend([] for _ in batch)
                continue

            for vector in vectors:
                if len(vector) == expected_dim:
                    embeddings.append(vector)
                else:
                    print(f"Warning: Skipping embedding due to incorrect dimension. Expected {expected_dim}, got {len(vector)}.")
                    embeddings.append([])  # placeholder keeps positions aligned

        return embeddings

    def _embed_batch(self, batch: List[str], delay_seconds: float, max_attempts: int) -> Optional[List[List[float]]]:
        """Request embeddings for one batch, retrying on failure; ``None`` if every attempt fails."""
        attempts = max(1, max_attempts)
        for attempt in range(1, attempts + 1):
            try:
                result = genai.embed_content(
                    model=self.model,
                    content=batch,
                    task_type="retrieval_document"
                )
                vectors = result['embedding'] if 'embedding' in result else None
                # A response without exactly one vector per text cannot be
                # mapped back to the batch, so treat it as a failed attempt.
                if vectors is None or len(vectors) != len(batch):
                    raise ValueError("response does not contain one embedding per input text")
                return vectors
            except Exception as e:
                print(f"Error generating embeddings (attempt {attempt}/{attempts}): {e}")
            finally:
                # Pause after every attempt, successful or not.
                if delay_seconds > 0:
                    time.sleep(delay_seconds)
        return None

class AstraDBManager:
    """Write embedded text chunks into an Astra DB collection.

    Args:
        api_endpoint: Endpoint passed to ``DataAPIClient.get_database``.
        token: Application token used to build the ``DataAPIClient``.
        collection_name: Collection that ``add_documents`` writes to.
        namespace: Stored on the instance (``"default_namespace"`` when
            omitted); it is not passed to ``get_database``.
    """

    def __init__(self, api_endpoint: str, token: str, collection_name: str, namespace: Optional[str] = None):
        self.api_endpoint = api_endpoint
        self.token = token
        self.collection_name = collection_name
        self.namespace = namespace or "default_namespace"
        self.client = DataAPIClient(token)
        self.database = self.client.get_database(api_endpoint)

    def get_or_create_collection(self, collection_name: str, dimension: int = 768):
        """Return the named collection, creating it with cosine metric if absent.

        Args:
            collection_name: Name of the collection to look up or create.
            dimension: Vector dimension used when the collection is created.

        Raises:
            Exception: Any error other than "collection already exists" is
                printed and re-raised.
        """
        try:
            print(f"Checking for collection {collection_name} in database.")
            # Compare against plain names: list_collections() yields
            # descriptor objects, so a string membership test on it never
            # matches.
            existing_names = self.database.list_collection_names()
            if collection_name in existing_names:
                print(f"* Collection {collection_name} already exists.")
                return self.database.get_collection(collection_name)

            print(f"* Collection {collection_name} does not exist. Creating...")
            collection = self.database.create_collection(
                name=collection_name,
                dimension=dimension,
                metric=VectorMetric.COSINE,
            )
            print(f"* Collection {collection_name} created successfully.")
            return collection
        except CollectionAlreadyExistsException:
            # Covers a collection created between the listing and the create call.
            print(f"* Collection {collection_name} already exists. Skipping creation.")
            return self.database.get_collection(collection_name)
        except Exception as e:
            print(f"Error handling collection {collection_name}: {e}")
            raise

    def add_documents(
        self,
        embeddings: List[List[float]],
        doc_name: str,
        references: str,
        chunks: List[str],
        dimension: int = 768,
    ):
        """Insert one document per ``(embedding, chunk)`` pair.

        ``embeddings[i]`` is stored together with ``chunks[i]``. Entries whose
        vector length differs from *dimension* are skipped. Insert errors are
        printed and do not stop the remaining inserts.

        Args:
            embeddings: Vectors, positionally matched to *chunks*.
            doc_name: Stored under ``metadata.doc_name``.
            references: Stored under ``metadata.references``.
            chunks: Text stored as each document's ``content``.
            dimension: Required vector length; also used if the collection
                has to be created.
        """
        if not embeddings:
            # Nothing to store, so skip the database round-trip.
            print(f"No embeddings to insert for {doc_name}.")
            return

        if len(embeddings) != len(chunks):
            # Pairing is positional; surplus items on either side have no partner.
            print(
                f"Warning: {len(embeddings)} embeddings for {len(chunks)} chunks; "
                f"only the first {min(len(embeddings), len(chunks))} pairs are considered."
            )

        collection = self.get_or_create_collection(collection_name=self.collection_name, dimension=dimension)

        for index, (embedding, chunk) in enumerate(zip(embeddings, chunks)):
            if len(embedding) != dimension:  # Ensure embedding has the correct dimensions
                print(f"Skipping document {index} due to incorrect embedding dimension.")
                continue

            doc_id = str(uuid.uuid4())
            document = {
                "_id": doc_id,
                "content": chunk,
                "$vector": embedding,
                "metadata": {
                    "doc_name": doc_name,
                    "references": references,
                }
            }
            try:
                result = collection.insert_one(document)
                if result.inserted_id:
                    print(f"Inserted document {doc_id}")
                else:
                    print(f"Failed to insert document {doc_id}")
            except Exception as e:
                print(f"Error processing document {doc_id}: {e}")


def extract_references(text: str) -> str:
    """Return the first reference code found in *text*.

    A reference code is four dash-separated digit groups, e.g. ``2-8-7-22``.
    When *text* contains none, the literal ``"No references"`` is returned.
    """
    first_hit = re.search(r'\d+-\d+-\d+-\d+', text)
    if first_hit is None:
        return "No references"
    return first_hit.group(0)

def main(file_paths: List[str]):
    """Ingest each file: load, split, embed, then store in Astra DB.

    Args:
        file_paths: Paths of the documents to process, handled in order.
    """
    # The pipeline stages are built once and shared by every file.
    splitter = RecursiveCharacterTextSplitterComponent()
    embedder = GeminiEmbeddingsComponent(gemini_api_key=GEMINI_API_KEY)
    store = AstraDBManager(
        api_endpoint=ASTRA_DB_API_ENDPOINT,
        token=ASTRA_DB_APPLICATION_TOKEN,
        collection_name="googlenashriat"
    )

    for file_path in file_paths:
        print(f"Processing file: {file_path}")

        # Load the document and cut its text into chunks.
        data, doc_name = DocumentProcessor(file_path).load_file()
        full_text = data.content
        chunks = splitter.split_text(full_text)

        # One embedding request series per file.
        vectors = embedder.build_embeddings(chunks)

        # The reference code is taken from the whole document, not per chunk.
        references = extract_references(full_text)

        # Persist vectors together with their chunks.
        store.add_documents(vectors, doc_name, references, chunks)
        print(f"Finished processing file: {file_path}")

# Entry point of the ingestion cell: runs only when executed as a script.
if __name__ == "__main__":
    # Documents to ingest, passed to main() in this order (Mabhas_05 to Mabhas_22).
    file_paths = [


        "/content/drive/MyDrive/Mabahes/Mabhas_05.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_06.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_07.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_08.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_09.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_10.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_11.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_12.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_13.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_14.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_15.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_16.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_17.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_18.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_19.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_20.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_21.txt",
        "/content/drive/MyDrive/Mabahes/Mabhas_22.txt",

    ]
    main(file_paths)


Processing file: /content/drive/MyDrive/Mabahes/Mabhas_05.txt
Error generating embeddings: HTTPConnectionPool(host='localhost', port=36743): Read timed out. (read timeout=60.0)


KeyboardInterrupt: 

In [None]:
!pip install astrapy langchain-community openai pypdf python-dotenv docx2txt tiktoken nltk sentence-transformers transformers torch hazm

In [None]:
import uuid
from typing import List
import google.generativeai as genai
from astrapy.database import Database
from astrapy.collection import Collection

class GeminiQueryComponent:
    """Turn a search query into an embedding through ``genai.embed_content``.

    Args:
        gemini_api_key: Key passed to ``genai.configure``.
        model_name: Embedding model identifier sent with every request.
    """

    def __init__(self, gemini_api_key: str, model_name: str = "models/text-embedding-004"):
        self.model = model_name
        genai.configure(api_key=gemini_api_key)

    def build_query_embedding(self, query: str, task_type: str = "retrieval_query") -> List[float]:
        """Return the embedding vector for *query*, or ``[]`` on any failure.

        Args:
            query: The search text.
            task_type: Task type sent to the embedding service. Search
                queries use ``"retrieval_query"``, the counterpart of the
                ``"retrieval_document"`` type used for stored text. Pass
                ``"retrieval_document"`` to get the previous behaviour.

        Returns:
            The query vector. An empty list when *query* is blank, the
            request fails, or the response has no ``'embedding'`` key.
            Failures are printed, not raised.
        """
        if not query or not query.strip():
            print("Error generating query embedding: query text is empty.")
            return []

        try:
            result = genai.embed_content(
                model=self.model,
                content=[query],
                task_type=task_type
            )

            if 'embedding' in result:
                # The query is sent as a one-item list, so take the single vector.
                return result['embedding'][0]
            raise ValueError("Unexpected result format. No 'embedding' key found.")
        except Exception as e:
            print(f"Error generating query embedding: {e}")
            return []

class AstraDBQueryManager:
    """Run vector-similarity lookups against an existing Astra DB collection.

    Args:
        api_endpoint: Endpoint passed to ``DataAPIClient.get_database``.
        token: Application token used to build the ``DataAPIClient``.
        collection_name: Collection to query.
        namespace: Stored on the instance (``"default_namespace"`` when
            omitted); it is not passed to ``get_database``.
    """

    def __init__(self, api_endpoint: str, token: str, collection_name: str, namespace: Optional[str] = None):
        self.api_endpoint = api_endpoint
        self.token = token
        self.collection_name = collection_name
        self.namespace = namespace or "default_namespace"
        self.client = DataAPIClient(token)
        self.database = self.client.get_database(api_endpoint)
        self.collection = self.database.get_collection(collection_name)

    def query_documents(self, query_embedding: List[float], top_k: int = 5):
        """Return up to *top_k* documents nearest to *query_embedding*.

        The lookup is a ``find`` sorted by ``$vector``. The similarity metric
        is a property of the collection, so it is not passed per query. Each
        returned document also carries a ``$similarity`` score.

        Args:
            query_embedding: Query vector; an empty vector yields ``[]``.
            top_k: Maximum number of documents to return.

        Returns:
            A list of document dicts, or ``[]`` when nothing matches or the
            query fails. Errors are printed, not raised.
        """
        if not query_embedding:
            print("Error querying documents: query embedding is empty.")
            return []

        try:
            cursor = self.collection.find(
                sort={"$vector": query_embedding},
                limit=top_k,
                include_similarity=True,
            )
            # find() returns a cursor; materialise it so callers get a list.
            documents = list(cursor)
        except Exception as e:
            print(f"Error querying documents: {e}")
            return []

        if not documents:
            print("No documents found for the query.")
        return documents

def main_query(query_text: str, collection_name: str = "gnashriat", top_k: int = 5):
    """Embed *query_text*, search the collection and print the matches.

    Args:
        query_text: The question to search for.
        collection_name: Astra DB collection to query.
        top_k: Maximum number of documents to fetch.
    """
    # NOTE(review): the ingestion pipeline in this file writes to
    # "googlenashriat", while this default reads from "gnashriat". Confirm
    # which collection is intended; pass collection_name explicitly to query
    # the ingested data.
    gemini_query = GeminiQueryComponent(gemini_api_key=GEMINI_API_KEY)
    query_embedding = gemini_query.build_query_embedding(query_text)

    if not query_embedding:
        # The embedding component has already printed the cause.
        return

    print("Query embedding generated successfully. Querying the database...")

    # Query the database for similar documents
    db_query_manager = AstraDBQueryManager(
        api_endpoint=ASTRA_DB_API_ENDPOINT,
        token=ASTRA_DB_APPLICATION_TOKEN,
        collection_name=collection_name
    )
    documents = db_query_manager.query_documents(query_embedding, top_k=top_k)

    if not documents:
        print("No documents found.")
        return

    print(f"Found {len(documents)} documents matching the query:")
    for doc in documents:
        print(f"Document ID: {doc['_id']}")
        print(f"Content: {doc['content']}")
        print(f"Metadata: {doc['metadata']}")
        print("="*50)

# Entry point of the query cell: runs only when executed as a script.
if __name__ == "__main__":
    # Sample Persian query; in English, roughly: "What are the general conditions
    # for residential, office, commercial and industrial buildings?"
    query_text = "شرایط عمومی ساختمان‌های مسکونی، اداری، تجاری و صنعتی چیست؟"
    main_query(query_text)
