In [7]:
import requests
from bs4 import BeautifulSoup
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from urllib.parse import urljoin
import time
import faiss
import numpy as np
from tqdm.notebook import tqdm
import json

# Configure Selenium Chrome options
chrome_options = Options()
chrome_options.add_argument('--headless')  # Run in headless mode
chrome_options.add_argument('--disable-gpu')
chrome_options.add_argument('--no-sandbox')
chrome_options.add_argument('--disable-dev-shm-usage')

# Initialize the Chrome driver
service = Service()
driver = webdriver.Chrome(service=service, options=chrome_options)

def get_visitable_links(base_url):
    """Scrapes all valid links from the base_url."""
    driver.get(base_url)
    time.sleep(3)
    soup = BeautifulSoup(driver.page_source, 'html.parser')

    links = []
    for a_tag in soup.find_all('a', href=True):
        full_url = urljoin(base_url, a_tag['href'])
        if '#' not in full_url and '%' not in full_url and full_url.startswith(base_url) and full_url not in links:
            links.append(full_url)
    return links

def scrape_text_from_url(url):
    """Scrapes the textual content from a given URL."""
    response = requests.get(url)
    if response.status_code != 200:
        print(f"Failed to retrieve {url}")
        return ""
    soup = BeautifulSoup(response.text, 'html.parser')
    
    # Use get_text() without strip=True to preserve spaces between inline elements
    text_content = ' '.join([p.get_text() for p in soup.find_all('p')])
    
    # Optionally, normalize excessive whitespace
    text_content = ' '.join(text_content.split())
    
    return text_content

def partition_text(text, max_length=512):
    """Partitions the text into smaller parts to feed into RAG."""
    sentences = text.split('. ')
    partitions = []
    current_part = []
    current_length = 0
    
    for sentence in sentences:
        current_length += len(sentence.split())
        current_part.append(sentence)
        
        if current_length > max_length:
            partitions.append('. '.join(current_part))
            current_part = []
            current_length = 0

    if current_part:
        partitions.append('. '.join(current_part))

    return partitions


ollama_url_emb = "http://localhost:11434/api/embeddings"
ollama_url_gen = "http://localhost:11434/api/generate"

def get_embedding(text):
    payload = {
        "model": "llama3.2:latest",
        "prompt": text
    }
    
    headers = {"Content-Type": "application/json"}
    response = requests.post(ollama_url_emb, headers=headers, data=json.dumps(payload))
    
    if response.status_code == 200:
        result = response.json()
        embedding = np.array(result['embedding'])
        print(f"Embedding dimension: {embedding.shape}")  # Add this line to check the embedding shape
        return embedding
    else:
        print(f"Error from Ollama: {response.status_code}")
        return np.zeros(768)  # Return zero vector on error (adjust dimension based on model)

def store_in_faiss(partitions, embedding_func):
    """Stores document embeddings in a FAISS index."""
    embedding_sample = embedding_func(partitions[0])  # Get one embedding to check dimension
    dimension = embedding_sample.shape[0]  # Automatically determine dimension
    print(f"embedding_sample dimension: {dimension}")
    index = faiss.IndexFlatL2(dimension)  # L2 distance

    doc_vectors = []
    doc_ids = []
    
    for i, partition in enumerate(partitions):
        print(f"Embedding partition {i+1} of {len(partitions)}")  # Debugging line to trace embedding calls
        embedding = embedding_func(partition)  # Calling get_embedding()
        index.add(np.array([embedding]))  # Add to FAISS index
        doc_vectors.append(embedding)
        doc_ids.append(i)
    
    return index, doc_ids

def retrieve_with_rag(query, faiss_index, doc_ids):
    query_embedding = get_embedding(query)
    print(f'embedding length: {len(query_embedding)}')
    
    # Search in FAISS index
    print("Performing FAISS search...")
    D, I = faiss_index.search(np.array([query_embedding]), k=2)  # Retrieve top-1 closest documents
    print(f"FAISS search done. D: {D}, I: {I}")

    # Check if FAISS returned valid results
    if len(I) == 0 or len(I[0]) == 0:
        print("No documents found.")
        return
    
    # For simplicity, concatenate the closest documents and feed to RAG
    retrieved_docs = []
    for i in I[0]:
        if i >= len(partitions):
            print(f"Index {i} out of bounds for partition list.")
            continue
        doc_id = doc_ids[i]
        retrieved_docs.append(partitions[doc_id])  # Get the document text from partition list
    
    combined_docs = "\n".join(retrieved_docs)
    print(f"Retrieved docs: {combined_docs}")
    
    # Perform RAG with retrieved docs
    rag_prompt = f"Context:\n{combined_docs}\n\nQuery: {query}\nAnswer:"
    print(f"RAG Prompt: {rag_prompt}")
    
    response = requests.post(ollama_url_gen, headers={"Content-Type": "application/json"},
                             data=json.dumps({"model": "llama3.2:latest", "prompt": rag_prompt}))
    
    print(response)
    if response.status_code == 200:
        return response.json()["completion"]
    else:
        return f"Error from Ollama: {response.status_code}"

    
# Main logic to scrape and partition data
base_url = 'https://en.wikipedia.org/wiki/Bangladesh'
valid_links = get_visitable_links(base_url)[1:4]

print(f"Found {len(valid_links)} valid links. Scraping the content...")

# Scrape text content from all valid links
all_text = ""
for link in tqdm(valid_links):
    text_content = scrape_text_from_url(link)
    all_text += text_content + "\n\n"
    
# Partition the text
partitions = partition_text(all_text, max_length=512)
print(f'partition length: {len(partitions)}')

# Store the partitions in FAISS
faiss_index, doc_ids = store_in_faiss(partitions, embedding_func=get_embedding)

# Query and retrieval with RAG
query = "Where is the location of Bangladesh?"
rag_response = retrieve_with_rag(query, faiss_index, doc_ids)

print(f"RAG Response:\n{rag_response}")

# Close the Selenium browser
#driver.quit()

Found 3 valid links. Scraping the content...


  0%|          | 0/3 [00:00<?, ?it/s]

partition length: 4
Embedding dimension: (3072,)
embedding_sample dimension: 3072
Embedding partition 1 of 4
Embedding dimension: (3072,)
Embedding partition 2 of 4
Embedding dimension: (3072,)
Embedding partition 3 of 4
Embedding dimension: (3072,)
Embedding partition 4 of 4
Embedding dimension: (3072,)


TypeError: object of type 'IndexFlatL2' has no len()