In [26]:
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from bs4 import BeautifulSoup
from transformers import AutoTokenizer
import re
import pandas as pd
from io import StringIO
from tqdm.notebook import tqdm
import time

# Local directory from which the tokenizer (and, in a later cell, the
# SentenceTransformer embedding model) is loaded.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
model_path = '/Users/hissain/git/github/models/all-MiniLM-L6-v2'
# Module-level tokenizer used by the chunking helpers to measure text in tokens.
tokenizer = AutoTokenizer.from_pretrained(model_path)

import os
# Sets TOKENIZERS_PARALLELISM to "false" for this process
# (presumably to silence the tokenizers parallelism/fork warning — TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def init_driver():
    """Create a Chrome WebDriver configured for headless use.

    Chrome is started with the ``--headless``, ``--no-sandbox`` and
    ``--disable-dev-shm-usage`` arguments and a default ``Service()``.

    Returns:
        The ``webdriver.Chrome`` instance that was created.
    """
    chrome_options = webdriver.ChromeOptions()
    for flag in ('--headless', '--no-sandbox', '--disable-dev-shm-usage'):
        chrome_options.add_argument(flag)
    return webdriver.Chrome(service=Service(), options=chrome_options)

def extract_text_from_url(url, driver):
    """Load ``url`` in ``driver`` and return the text of the page's ``<body>``.

    Before the text is collected, every ``script``, ``style``, ``footer``,
    ``header``, ``nav``, ``a`` and ``table`` element is removed from the body.

    Args:
        url: Address passed to ``driver.get``.
        driver: Object exposing ``get(url)`` and ``page_source``.

    Returns:
        The remaining body text with pieces joined by a single space, or
        ``""`` when the parsed document has no ``<body>``.
    """
    driver.get(url)
    time.sleep(1)  # fixed one-second pause between navigation and reading the source

    body = BeautifulSoup(driver.page_source, "html.parser").body
    if body is None:
        return ""

    unwanted_tags = ["script", "style", "footer", "header", "nav", "a", "table"]
    for element in body.find_all(unwanted_tags):
        element.decompose()
    return body.get_text(separator=" ")

def prepare_table_for_rag_by_token_count(url, df, window_len_tokens):
    """Turn a DataFrame into text chunks bounded by a token window.

    Each row is rendered as ``"col: value | col: value"``. Rows are packed
    into a chunk until adding the next row would push the token count (as
    measured by the module-level ``tokenizer``) past ``window_len_tokens``.

    Header handling:
      * If the header line fits in the window, it is prefixed to every chunk.
      * If the header line alone exceeds the window, it is emitted once as
        its own character slices of length ``window_len_tokens`` and is NOT
        repeated on row chunks (repeating it would make every chunk
        oversized).

    A chunk is only flushed once it holds at least one row, so no empty or
    header-only chunk is produced between rows. A single row that is larger
    than the window is still emitted whole, in a chunk of its own.

    Args:
        url: Source URL stored alongside each chunk.
        df: Table to serialise; column labels are stringified for the header.
        window_len_tokens: Maximum token budget per chunk.

    Returns:
        A list of ``{"text": str, "url": url}`` dicts. A table without rows
        yields just the header chunk(s).
    """
    header_text = "Table Headers: " + ", ".join(map(str, df.columns))
    chunks = []

    if len(tokenizer.encode(header_text)) > window_len_tokens:
        # Oversized header: store it once, sliced by characters, and do not
        # carry it into the row chunks.
        chunks.extend(
            {"text": header_text[i:i + window_len_tokens], "url": url}
            for i in range(0, len(header_text), window_len_tokens)
        )
        chunk_prefix = ""
    else:
        chunk_prefix = header_text + "\n"

    current_chunk = chunk_prefix
    rows_in_chunk = 0

    for _, row in df.iterrows():
        row_str = " | ".join(f"{col}: {val!s}" for col, val in row.items())
        row_tokens = tokenizer.encode(row_str)

        overflow = len(tokenizer.encode(current_chunk)) + len(row_tokens) > window_len_tokens
        # Flush only when there is at least one row to flush; otherwise the
        # chunk would be empty or contain nothing but the header.
        if overflow and rows_in_chunk:
            chunks.append({"text": current_chunk.strip(), "url": url})
            current_chunk = chunk_prefix
            rows_in_chunk = 0

        current_chunk += row_str + "\n"
        rows_in_chunk += 1

    if current_chunk.strip():
        chunks.append({"text": current_chunk.strip(), "url": url})

    return chunks

def extract_table_from_url(url, driver, window_len_tokens=512):
    """Load ``url`` and convert every HTML table on the page into RAG chunks.

    Args:
        url: Address passed to ``driver.get``; also stored on each chunk.
        driver: Object exposing ``get(url)`` and ``page_source``.
        window_len_tokens: Token budget per chunk, forwarded to
            ``prepare_table_for_rag_by_token_count``. Defaults to 512.

    Returns:
        A flat list of chunk dicts from all tables, in page order. Returns
        an empty list when the page contains no tables.
    """
    driver.get(url)
    time.sleep(1)  # fixed one-second pause between navigation and reading the source

    try:
        tables = pd.read_html(StringIO(driver.page_source))
    except ValueError:
        # pandas raises ValueError when it finds no tables in the document;
        # a table-less page is a normal case here, not a failure.
        return []

    all_chunks = []
    for table in tables:
        all_chunks.extend(prepare_table_for_rag_by_token_count(url, table, window_len_tokens))

    return all_chunks

def clean_text(text):
    """Normalise scraped text for chunking.

    Steps, in order: collapse every whitespace run to one space, delete any
    character that is not a word character, whitespace or one of
    ``. , ! ? ' " ( ) -``, trim the ends, and lowercase the result.

    Args:
        text: Raw text.

    Returns:
        The cleaned, lowercased string.
    """
    whitespace_run = re.compile(r'\s+')
    disallowed_char = re.compile(r'[^\w\s.,!?\'"()-]')

    single_spaced = whitespace_run.sub(' ', text)
    filtered = disallowed_char.sub('', single_spaced)
    return filtered.strip().lower()

def split_sentences(text):
    """Split ``text`` into sentences.

    A boundary is one or more spaces that directly follow ``.``, ``!`` or
    ``?``; the punctuation stays with the preceding sentence and the spaces
    are dropped.

    Args:
        text: Text to split.

    Returns:
        List of sentence strings (a one-element list when no boundary exists).
    """
    sentence_boundary = re.compile(r'(?<=[.!?]) +')
    return sentence_boundary.split(text)

def count_tokens(text):
    """Return how many token ids the module-level ``tokenizer`` produces for ``text``.

    The count is simply the length of ``tokenizer.encode(text)``.
    """
    token_ids = tokenizer.encode(text)
    return len(token_ids)

def partition_sentences(sentences, url, max_tokens=512, overlap=1):
    """Group sentences into chunks of roughly ``max_tokens`` tokens.

    Sentences are accumulated until adding the next one would push the
    running token count (per ``count_tokens``) above ``max_tokens``. The
    chunk is then emitted and the next chunk starts with the last
    ``overlap`` sentences of the previous one.

    Args:
        sentences: Iterable of sentence strings, in document order.
        url: Source URL stored alongside each chunk.
        max_tokens: Token budget per chunk.
        overlap: Number of trailing sentences carried into the next chunk.
            ``0`` (or a negative value) disables overlap.

    Returns:
        A list of ``{"text": str, "url": url}`` dicts; sentences within a
        chunk are joined by a single space. No empty chunk is emitted. A
        single sentence longer than ``max_tokens`` is kept whole, so such a
        chunk may exceed the budget.
    """
    chunks, current_chunk = [], []
    current_tokens = 0

    for sentence in sentences:
        sentence_tokens = count_tokens(sentence)

        # Only flush when something has been collected; otherwise an
        # oversized first sentence would produce an empty chunk.
        if current_chunk and current_tokens + sentence_tokens > max_tokens:
            chunks.append({"text": " ".join(map(str, current_chunk)), "url": url})
            # ``lst[-0:]`` is the whole list, so zero overlap needs an explicit reset.
            current_chunk = current_chunk[-overlap:] if overlap > 0 else []
            current_tokens = count_tokens(" ".join(map(str, current_chunk))) if current_chunk else 0

        current_chunk.append(sentence)
        current_tokens += sentence_tokens

    if current_chunk:
        chunks.append({"text": " ".join(map(str, current_chunk)), "url": url})

    return chunks

def process_urls(urls):
    """Scrape each URL and return the combined text and table chunks.

    For every URL the body text is extracted, cleaned, split into sentences
    and partitioned into chunks (512-token budget, one-sentence overlap);
    then the page's tables are chunked via ``extract_table_from_url``.
    An exception raised while handling a URL is printed and processing
    continues with the next URL; chunks already collected are kept.

    Args:
        urls: Iterable of page URLs.

    Returns:
        A list of ``{"text": ..., "url": ...}`` chunk dicts in processing order.
    """
    driver = init_driver()
    all_chunks = []

    try:
        for url in tqdm(urls, desc="Processing URLs"):
            try:
                raw_text = extract_text_from_url(url, driver)
                clean_text_content = clean_text(raw_text)
                sentences = split_sentences(clean_text_content)
                chunks = partition_sentences(sentences, url, max_tokens=512, overlap=1)
                all_chunks.extend(chunks)

                table_chunks = extract_table_from_url(url, driver)
                all_chunks.extend(table_chunks)

            except Exception as e:
                print(f"Failed to process {url}: {e}")
    finally:
        # Always shut the browser down, even if iteration is interrupted,
        # so each call does not leave a Chrome process behind.
        driver.quit()

    return all_chunks

# Pages to ingest.
urls = [
    "https://en.wikipedia.org/wiki/List_of_wars_by_death_toll",
]

# Scrape every URL and keep the resulting chunk dicts for the cells below.
rag_chunks = process_urls(urls)
print(f"Total chunks: {len(rag_chunks)}")

Processing URLs:   0%|          | 0/1 [00:00<?, ?it/s]

Total chunks: 38


In [28]:
# Preview the chunks at positions 15-19 of rag_chunks.
# Note: the printed "Chunk N" label restarts at 1 for this slice; it is not the index in rag_chunks.
for i, chunk in enumerate(rag_chunks[15:20]):
    print(f"Chunk {i+1}:\nText: {chunk['text']}\nURL: {chunk['url']}\n")

Chunk 1:
Text: . bercovitch, jacob jackson, richard (1997). . congressional quarterly. . . www.ictj.org . 2009-01-01 . retrieved 2024-09-30 . . 2011-05-22. archived from on 2011-05-22 . retrieved 2024-09-30 . . scribd . retrieved 2024-10-03 . marley, david (1998). . abc-clio. . farcau, bruce w. (1996-05-23). . bloomsbury academic. . web, apunto. . portal guarani (in european spanish) . retrieved 2024-10-01 . tinker-salas, miguel (2015). . oxford university press. . mwakikagile, godfrey (2014-04-21). . new africa press. . arrian (1884). . cornell university library. london, hodder and stoughton. arrian (2018-04-10). . ozymandias press. . staff, historynet (2007-09-17). . historynet . retrieved 2024-10-03 . . herre, bastian rodés-guirao, lucas roser, max (2024-03-20). . our world in data . herre, bastian roser, max (2024-07-15). . our world in data . further reading   (2011). . penguin books. . pp. 832. (see also ) levy, jack s. (1983). . university press of kentucky, usa. . external lin

In [29]:
import numpy as np
from qdrant_client import QdrantClient, models
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer
from IPython.display import display, clear_output, Markdown
import requests
import json
import asyncio

# Shared HTTP session; the headers below are defaults for every request sent through it.
session = requests.Session()
session.headers.update({"Connection": "keep-alive", "Content-Type": "application/json"})

# Qdrant endpoint and the collection used to store chunk vectors.
qdrant_url = "http://localhost:6333"
collection_name = "github_collection"

# Ollama endpoints and model name.
# NOTE(review): ollama_url_inf and ollama_url_emb are not referenced in the cells shown here.
ollama_url_inf = "http://localhost:11434/api/show"
ollama_url_emb = "http://localhost:11434/api/embeddings"
ollama_url_gen = "http://localhost:11434/api/generate"
ollama_model_name = "llama3.2:latest"

# Qdrant client and the SentenceTransformer loaded from model_path (defined in an earlier cell).
client = QdrantClient(url=qdrant_url)
embedding_model = SentenceTransformer(model_path)

def get_embedding(text):
    """Encode ``text`` with the module-level ``embedding_model`` and return the result unchanged."""
    vector = embedding_model.encode(text)
    return vector

def create_collection(dimension):
    """Drop the collection named ``collection_name`` and create it afresh.

    Args:
        dimension: Vector size for the new collection, which uses cosine
            distance.
    """
    try:
        client.delete_collection(collection_name=collection_name)
    except Exception:
        # Deletion is best-effort: whatever goes wrong here is ignored and
        # creation is attempted regardless.
        pass

    vector_settings = models.VectorParams(size=dimension, distance=models.Distance.COSINE)
    client.create_collection(collection_name=collection_name, vectors_config=vector_settings)
    
def upsert_points_with_metadata(embeddings, chunks):
    """Write one point per (embedding, chunk) pair into ``collection_name``.

    Point ids are the pair's position (0, 1, 2, ...). Each point stores the
    embedding as a plain list and a payload with the chunk's ``text`` and
    ``url``. Pairing stops at the shorter of the two inputs.

    Args:
        embeddings: Sequence of vectors exposing ``.tolist()``.
        chunks: Sequence of dicts with ``"text"`` and ``"url"`` keys.
    """
    points = []
    for point_id, (vector, chunk) in enumerate(zip(embeddings, chunks)):
        vector_values = vector.tolist()
        metadata = {"text": chunk["text"], "url": chunk["url"]}
        points.append(models.PointStruct(id=point_id, vector=vector_values, payload=metadata))

    client.upsert(collection_name=collection_name, points=points)

def store_in_qdrant_with_metadata(chunks):
    """Embed ``chunks`` and store them in a freshly created collection.

    Embeddings are computed first, so a failure while embedding leaves any
    existing collection untouched. The collection's vector size is taken
    from the first embedding rather than being hard-coded, which keeps the
    function correct if a model with a different output size is loaded.

    Args:
        chunks: Sequence of dicts with ``"text"`` and ``"url"`` keys.
    """
    default_dimension = 384  # Dimension for 'all-MiniLM-L6-v2'; used only when there is nothing to embed

    embeddings = [get_embedding(chunk["text"]) for chunk in tqdm(chunks, desc="Generating embeddings")]
    dimension = len(embeddings[0]) if embeddings else default_dimension

    create_collection(dimension)
    if embeddings:
        # Nothing to write for an empty input; skip the upsert call entirely.
        upsert_points_with_metadata(embeddings, chunks)

def search_points_with_metadata(query_embedding, k=3):
    """Return the payloads of the ``k`` nearest stored points.

    Args:
        query_embedding: Query vector exposing ``.tolist()``.
        k: Maximum number of hits requested from the search.

    Returns:
        A list of ``{"text": ..., "url": ...}`` dicts, one per hit, in the
        order the search returned them.
    """
    hits = client.search(
        collection_name=collection_name,
        query_vector=query_embedding.tolist(),
        limit=k,
        with_payload=True
    )

    results = []
    for hit in hits:
        stored = hit.payload
        results.append({"text": stored["text"], "url": stored["url"]})
    return results

def ask(query, k=3, p=False):
    """Answer ``query`` with retrieval-augmented generation and stream the output.

    The query is embedded, the ``k`` nearest chunks are fetched, and a
    prompt made of an instruction, the retrieved context and the query is
    posted to ``ollama_url_gen`` with streaming enabled. The answer is
    rendered as Markdown and refreshed as text arrives.

    Args:
        query: The user's question.
        k: Number of chunks to retrieve as context.
        p: When true, print the full prompt before sending it.

    Returns:
        The complete generated text, or ``""`` if the request did not
        return HTTP 200 (the status and body are printed in that case).
    """
    query_embedding = get_embedding(query)
    retrieved_docs = search_points_with_metadata(query_embedding, k)

    combined_docs = "\n\n".join([f"Source: {doc['url']}\n{doc['text']}" for doc in retrieved_docs])
    inst = "Instruction: If you do not find the answer in the context, just say you don't know."
    rag_prompt = f"{inst}\n\nContext:\n{combined_docs}\n\nQuery: {query}\nAnswer:"
    if p:
        print(rag_prompt)

    payload = {"model": ollama_model_name, "prompt": rag_prompt, "stream": True}
    headers = {"Content-Type": "application/json"}

    response_text = ""
    buffer = ""

    response = session.post(ollama_url_gen, headers=headers, data=json.dumps(payload), stream=True)
    try:
        if response.status_code != 200:
            print("Request failed:", response.status_code, response.text)
            return response_text

        # The stream is parsed as one JSON object per line. Reading by line
        # (instead of by raw network chunk) keeps each object intact even
        # when several arrive together or one is split across chunks.
        for line in response.iter_lines():
            if not line:
                continue
            try:
                data = json.loads(line.decode('utf-8'))
            except json.JSONDecodeError:
                continue

            buffer += data.get("response", "")

            # Display output every few characters for real-time effect
            if len(buffer) > 10:
                response_text += buffer
                clear_output(wait=True)
                display(Markdown(response_text))
                buffer = ""

        # Display any remaining buffered content
        response_text += buffer
        clear_output(wait=True)
        display(Markdown(response_text))
    finally:
        # Release the streamed connection back to the session's pool.
        response.close()

    return response_text

# Embed the scraped chunks and write them to Qdrant.
# Any exception is caught and printed, so this cell reports the error instead of raising.
try:
    store_in_qdrant_with_metadata(rag_chunks)
    print(f'Stored {len(rag_chunks)} chunks')
except Exception as e:
    print(f"Error storing in Qdrant: {e}")


Generating embeddings:   0%|          | 0/38 [00:00<?, ?it/s]

Stored 38 chunks


In [30]:
# Sample question; the returned text is bound to `_` so the cell shows only what ask() displays.
_ = ask("How many death in Colombian conflict?")

0.45 million[128]

In [7]:
# Ollama chat endpoint used by chat().
ollama_url_chat = "http://localhost:11434/api/chat"

# Module-level list of chat messages ({"role": ..., "content": ...}) that chat() sends on every call.
chat_history = []

def chat(query, k=2, p=False, stream=True):
    """Multi-turn RAG chat against ``ollama_url_chat``.

    The query is embedded, the ``k`` nearest chunks are fetched, and a
    prompt made of an instruction, the retrieved context and the query is
    sent as a new user message together with the accumulated
    ``chat_history``. The reply is rendered as Markdown and refreshed as
    text arrives.

    ``chat_history`` is only updated after a successful (HTTP 200) request:
    the user message and the assistant's reply are appended as a pair, so
    later calls see both sides of the conversation and a failed request
    leaves the history unchanged.

    Args:
        query: The user's message.
        k: Number of chunks to retrieve as context.
        p: When true, print the full prompt before sending it.
        stream: Passed to the API payload and to the HTTP request.

    Returns:
        The complete reply text, or ``""`` if the request did not return
        HTTP 200 (the status and body are printed in that case).
    """
    global chat_history

    query_embedding = get_embedding(query)
    retrieved_docs = search_points_with_metadata(query_embedding, k)

    combined_docs = "\n\n".join([f"Source: {doc['url']}\n{doc['text']}" for doc in retrieved_docs])
    inst = "Instruction: If you do not find the answer in the context, just say you don't know."
    rag_prompt = f"{inst}\n\nContext:\n{combined_docs}\n\nQuery: {query}\nAnswer:"
    if p:
        print(rag_prompt)

    user_message = {"role": "user", "content": rag_prompt}
    payload = {"model": ollama_model_name, "messages": chat_history + [user_message], "stream": stream}

    response_text = ""
    buffer = ""

    response = session.post(ollama_url_chat, data=json.dumps(payload), stream=stream)
    try:
        if response.status_code != 200:
            print("Request failed:", response.status_code, response.text)
            return response_text

        # The body is parsed as one JSON object per line. Reading by line
        # (instead of by raw network chunk) keeps each object intact even
        # when several arrive together or one is split across chunks.
        for line in response.iter_lines():
            if not line:
                continue
            try:
                data = json.loads(line.decode('utf-8'))
            except json.JSONDecodeError:
                continue

            buffer += data.get("message", {}).get("content", "")

            # Display output every few characters for real-time effect
            if len(buffer) > 10:
                response_text += buffer
                clear_output(wait=True)
                display(Markdown(response_text))
                buffer = ""

        # Display any remaining buffered content
        response_text += buffer
        clear_output(wait=True)
        display(Markdown(response_text))
    finally:
        # Release the connection back to the session's pool.
        response.close()

    # Record both turns so the next call carries the full conversation.
    chat_history.extend([user_message, {"role": "assistant", "content": response_text}])

    return response_text