In [1]:
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from bs4 import BeautifulSoup
from transformers import AutoTokenizer
import re
import pandas as pd
from io import StringIO
from tqdm.notebook import tqdm
import time

# Local checkout of the model. The same directory is used here to load the
# tokenizer (for token counting) and further down to load the SentenceTransformer.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
model_path = '/Users/hissain/git/github/models/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(model_path)

import os
# presumably silences the Hugging Face tokenizers parallelism warning — TODO confirm
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def init_driver():
    """Create a headless Chrome WebDriver.

    Returns:
        A ``webdriver.Chrome`` instance started with the ``--headless``,
        ``--no-sandbox`` and ``--disable-dev-shm-usage`` flags and a default
        ``Service``. The caller owns the driver.
    """
    chrome_options = webdriver.ChromeOptions()
    for flag in ('--headless', '--no-sandbox', '--disable-dev-shm-usage'):
        chrome_options.add_argument(flag)
    return webdriver.Chrome(service=Service(), options=chrome_options)

def extract_text_from_url(url, driver):
    """Load ``url`` in the browser and return the visible body text.

    Args:
        url: Page to open.
        driver: WebDriver used to fetch and render the page.

    Returns:
        The text of ``<body>`` with fragments joined by single spaces, after
        removing script, style, footer, header, nav, a and table elements.
        Returns ``""`` when the parsed page has no ``<body>``.
    """
    driver.get(url)
    time.sleep(1)  # fixed pause after navigation before reading the page source

    body = BeautifulSoup(driver.page_source, "html.parser").body
    if body is None:
        return ""

    # Elements whose content is dropped entirely (link text included).
    unwanted_tags = ["script", "style", "footer", "header", "nav", "a", "table"]
    for element in body.find_all(unwanted_tags):
        element.decompose()

    return body.get_text(separator=" ")

def extract_and_prepare_table_chunks(url, driver, tokenizer, max_tokens=512):
    """Turn every HTML table on a page into token-bounded text chunks.

    Each table row is rendered as ``"col: value | col: value"``. Rows are packed
    into chunks that start with a ``"Table Headers: ..."`` line, and a new chunk
    is started whenever adding the next row would exceed ``max_tokens``.

    Args:
        url: Page to load; also stored on every chunk as ``"url"``.
        driver: WebDriver used to fetch the page source.
        tokenizer: Object with an ``encode(text)`` method whose result length is
            used as the token count.
        max_tokens: Token budget per chunk.

    Returns:
        A list of ``{"text": str, "url": str}`` dicts, in table and row order.
        Returns ``[]`` when the page contains no tables.
    """
    driver.get(url)
    time.sleep(1)  # fixed pause after navigation before reading the page source

    try:
        tables = pd.read_html(StringIO(driver.page_source))
    except ValueError:
        # pandas signals "No tables found" with ValueError; a page without
        # tables is an ordinary outcome here, not a processing failure.
        return []

    all_chunks = []

    for df in tables:
        chunks, current_chunk = [], ""
        header_text = "Table Headers: " + ", ".join(map(str, df.columns))

        if len(tokenizer.encode(header_text)) > max_tokens:
            # The header alone is over budget: emit it in fixed-width character
            # slices and do NOT repeat it in front of the rows. Repeating it
            # would push every row chunk over max_tokens.
            chunks.extend(
                {"text": header_text[i:i + max_tokens], "url": url}
                for i in range(0, len(header_text), max_tokens)
            )
            row_prefix = ""
        else:
            row_prefix = header_text + "\n"
            current_chunk = row_prefix

        for _, row in df.iterrows():
            row_str = " | ".join(f"{col}: {val!s}" for col, val in row.items())
            row_tokens = tokenizer.encode(row_str)

            # The running chunk is re-encoded on purpose: token counts are not
            # assumed to be additive across concatenated strings.
            if len(tokenizer.encode(current_chunk)) + len(row_tokens) > max_tokens:
                if current_chunk.strip():  # never emit an empty-text chunk
                    chunks.append({"text": current_chunk.strip(), "url": url})
                current_chunk = row_prefix + row_str + "\n"
            else:
                current_chunk += row_str + "\n"

        if current_chunk.strip():
            chunks.append({"text": current_chunk.strip(), "url": url})

        all_chunks.extend(chunks)

    return all_chunks

def clean_text(text):
    """Normalise scraped text before sentence splitting.

    Whitespace runs collapse to one space, every character other than word
    characters, whitespace and ``. , ! ? ' " ( ) -`` is removed, and the result
    is trimmed and lower-cased.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    filtered = re.sub(r'[^\w\s.,!?\'"()-]', '', collapsed)
    trimmed = filtered.strip()
    return trimmed.lower()

def split_sentences(text):
    """Split ``text`` at runs of spaces that follow ``.``, ``!`` or ``?``.

    The punctuation mark stays attached to the sentence it ends.
    """
    sentence_boundary = re.compile(r'(?<=[.!?]) +')
    return sentence_boundary.split(text)

def count_tokens(text):
    """Return the length of what the module-level ``tokenizer`` encodes ``text`` to."""
    token_ids = tokenizer.encode(text)
    return len(token_ids)

def partition_sentences(sentences, url, max_tokens=400, overlap=1):
    """Greedily pack sentences into chunks of at most ``max_tokens`` tokens.

    Sentences are appended to the current chunk until the next one would push
    the running token count over ``max_tokens``. The chunk is then emitted and a
    new one is started, seeded with the last ``overlap`` sentences of the
    previous chunk.

    Args:
        sentences: Iterable of sentence strings, in document order.
        url: Source URL stored on every chunk.
        max_tokens: Token budget per chunk, measured with ``count_tokens``. A
            single sentence longer than the budget is still emitted in a chunk.
        overlap: Number of trailing sentences carried into the next chunk.
            Values ``<= 0`` disable the overlap.

    Returns:
        A list of ``{"text": str, "url": str}`` dicts; ``[]`` for no sentences.
    """
    chunks, current_chunk = [], []
    current_tokens = 0

    for sentence in sentences:
        sentence_tokens = count_tokens(sentence)

        # Flush only when there is something to flush. Otherwise an oversized
        # first sentence would produce a chunk with empty text.
        if current_chunk and current_tokens + sentence_tokens > max_tokens:
            chunks.append({"text": " ".join(map(str, current_chunk)), "url": url})
            # `lst[-0:]` is the whole list, so overlap=0 needs an explicit
            # reset; without it the chunk would never shrink.
            current_chunk = current_chunk[-overlap:] if overlap > 0 else []
            current_tokens = (
                count_tokens(" ".join(map(str, current_chunk))) if current_chunk else 0
            )

        current_chunk.append(sentence)
        current_tokens += sentence_tokens

    if current_chunk:
        chunks.append({"text": " ".join(map(str, current_chunk)), "url": url})

    return chunks

def process_urls(urls):
    """Scrape each URL and return all prose and table chunks.

    For every URL the body text is extracted, cleaned, split into sentences and
    partitioned into chunks of up to 400 tokens with a one-sentence overlap.
    The tables on the same page are then chunked separately.

    Args:
        urls: Iterable of page URLs.

    Returns:
        A flat list of ``{"text": str, "url": str}`` chunks. A URL that raises
        is reported with a printed message and skipped; chunks already
        collected for it are kept.
    """
    driver = init_driver()
    all_chunks = []

    try:
        for url in tqdm(urls, desc="Processing URLs"):
            try:
                raw_text = extract_text_from_url(url, driver)
                clean_text_content = clean_text(raw_text)
                sentences = split_sentences(clean_text_content)
                chunks = partition_sentences(sentences, url, max_tokens=400, overlap=1)
                all_chunks.extend(chunks)

                table_chunks = extract_and_prepare_table_chunks(url, driver, tokenizer)
                all_chunks.extend(table_chunks)

            except Exception as e:
                # Best effort by design: one bad page must not abort the batch.
                print(f"Failed to process {url}: {e}")
    finally:
        # Always shut the browser down; otherwise every call leaves a Chrome
        # process running.
        driver.quit()

    return all_chunks

# Pages to ingest into the RAG index.
urls = [
    "https://en.wikipedia.org/wiki/List_of_wars_by_death_toll",
]

# Scrape and chunk every page; `rag_chunks` is the input to the indexing cell below.
rag_chunks = process_urls(urls)
print(f"Total chunks: {len(rag_chunks)}")

Processing URLs:   0%|          | 0/1 [00:00<?, ?it/s]

Total chunks: 44


In [2]:
# Manual spot check of five chunks (indices 15-19). The printed labels restart
# at 1, so "Chunk 1" is rag_chunks[15].
for i, chunk in enumerate(rag_chunks[15:20]):
    print(f"Chunk {i+1}:\nText: {chunk['text']}\nURL: {chunk['url']}\n")

Chunk 1:
Text: . bbc news . 2014-09-23 . retrieved 2024-10-03 . . . levy, jack s. (1983). . university press of kentucky. . . clodfelter, micheal (2008). . internet archive. jefferson, n.c. mcfarland. . nolan, cathal j. (2006). . internet archive. westport, conn. greenwood press. . momodu, samuel (2016-07-25). . retrieved 2024-10-02 . . remilitari.com . retrieved 2024-10-02 . (pdf) . yi, ki-baek (1984). . internet archive. cambridge, mass. published for the harvard-yenching institute by harvard university press. . shi, li. . deeplogic. salvador, antonio caridad (2018-05-23). . cuadernos de historia contemporánea (in spanish). 40  149167. . . . the costs of war . retrieved 2024-10-03 . . www.pbs.org . retrieved 2024-09-29 . . www.forcesnews.com . 2021-02-28 . retrieved 2024-09-29 . gillespie, caitlin c. (2018-01-15). . oxford university press.
URL: https://en.wikipedia.org/wiki/List_of_wars_by_death_toll

Chunk 2:
Text: oxford university press. . crane, nicholas (2016-10-13). . orion. .

In [3]:
import numpy as np
from qdrant_client import QdrantClient, models
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer
from IPython.display import display, clear_output, Markdown
import requests
import json
import asyncio

# One shared HTTP session, with default JSON headers, reused for the Ollama requests.
session = requests.Session()
session.headers.update({"Connection": "keep-alive", "Content-Type": "application/json"})

# Local Qdrant instance and the collection that stores the chunk vectors.
qdrant_url = "http://localhost:6333"
collection_name = "github_collection"

# Local Ollama endpoints and model tag.
# NOTE(review): ollama_url_inf and ollama_url_emb are not referenced by any cell shown here.
ollama_url_inf = "http://localhost:11434/api/show"
ollama_url_emb = "http://localhost:11434/api/embeddings"
ollama_url_gen = "http://localhost:11434/api/generate"
ollama_model_name = "llama3.2:latest"

client = QdrantClient(url=qdrant_url)
# Embeddings come from the same local model directory the tokenizer was loaded from.
embedding_model = SentenceTransformer(model_path)

def get_embedding(text):
    """Encode ``text`` with the module-level ``embedding_model`` and return the result."""
    vector = embedding_model.encode(text)
    return vector

def create_collection(dimension):
    """Recreate the Qdrant collection ``collection_name`` from scratch.

    Args:
        dimension: Size of the vectors the collection will store.

    Any existing collection of that name is deleted first, so calling this
    discards previously indexed points. The new collection uses cosine distance.
    """
    try:
        client.delete_collection(collection_name=collection_name)
    except Exception:
        # Deliberately best effort: a failed delete (presumably because the
        # collection does not exist yet) must not block creation.
        pass

    cosine_vectors = models.VectorParams(size=dimension, distance=models.Distance.COSINE)
    client.create_collection(collection_name=collection_name, vectors_config=cosine_vectors)
    
def upsert_points_with_metadata(embeddings, chunks):
    """Write one Qdrant point per (embedding, chunk) pair in a single upsert.

    Args:
        embeddings: Vectors exposing ``tolist()``, aligned with ``chunks``.
        chunks: Dicts with ``"text"`` and ``"url"`` keys, stored as the payload.

    Point ids are the 0-based positions of the pairs.
    """
    batch = []
    for point_id, (vector, chunk) in enumerate(zip(embeddings, chunks)):
        vector_values = vector.tolist()
        metadata = {"text": chunk["text"], "url": chunk["url"]}
        batch.append(models.PointStruct(id=point_id, vector=vector_values, payload=metadata))

    client.upsert(collection_name=collection_name, points=batch)

def store_in_qdrant_with_metadata(chunks):
    """Embed every chunk and index the vectors in a freshly created collection.

    Args:
        chunks: List of dicts with ``"text"`` and ``"url"`` keys.

    Embeddings are computed before the collection is recreated, so a failure
    while embedding leaves the existing collection untouched. The vector size
    is taken from the first embedding instead of being hard-coded, so switching
    the embedding model cannot create a collection of the wrong size. With no
    chunks, an empty collection of size 384 (the all-MiniLM-L6-v2 dimension) is
    created and nothing is upserted.
    """
    embeddings = [get_embedding(chunk["text"]) for chunk in tqdm(chunks, desc="Generating embeddings")]

    # Fall back to the dimension of 'all-MiniLM-L6-v2' when there is nothing to measure.
    dimension = len(embeddings[0]) if embeddings else 384
    create_collection(dimension)

    if embeddings:
        upsert_points_with_metadata(embeddings, chunks)

def search_points_with_metadata(query_embedding, k=3):
    """Return the payloads of the ``k`` best matches for ``query_embedding``.

    Args:
        query_embedding: Query vector exposing ``tolist()``.
        k: Maximum number of hits to request.

    Returns:
        A list of ``{"text": ..., "url": ...}`` dicts, in the order the client
        returns the hits.
    """
    hits = client.search(
        collection_name=collection_name,
        query_vector=query_embedding.tolist(),
        limit=k,
        with_payload=True,
    )

    results = []
    for hit in hits:
        results.append({"text": hit.payload["text"], "url": hit.payload["url"]})
    return results

# Embed and index every scraped chunk. A failure is reported as a one-line
# message instead of a traceback, so later cells may then run against a
# missing or incomplete collection.
try:
    store_in_qdrant_with_metadata(rag_chunks)
    print(f'Stored {len(rag_chunks)} chunks')
except Exception as e:
    print(f"Error storing in Qdrant: {e}")


Generating embeddings:   0%|          | 0/44 [00:00<?, ?it/s]

Stored 44 chunks


In [4]:
def ask(query, k=3, p=False):
    """Answer ``query`` with a single-turn RAG call to the Ollama generate endpoint.

    The query is embedded, the ``k`` nearest chunks are fetched from Qdrant and
    wrapped in a ``<CONTEXT>`` block, and the streamed answer is rendered as
    Markdown while it arrives.

    Args:
        query: Natural-language question.
        k: Number of chunks to retrieve as context.
        p: When True, print the assembled prompt and also prefix it to the
            rendered and returned text (presumably so it stays visible after
            ``clear_output``).

    Returns:
        The accumulated answer text, preceded by the prompt when ``p`` is True.
        On a non-200 response the error is printed and the text gathered so far
        (empty, or just the prompt) is returned.
    """
    query_embedding = get_embedding(query)
    retrieved_docs = search_points_with_metadata(query_embedding, k)

    combined_docs = "\n\n".join([f"Source: {doc['url']}\n{doc['text']}" for doc in retrieved_docs])
    inst = "Instruction: If you do not find the answer in the CONTEXT, just say you don't know."
    rag_prompt = f"{inst}\n\n<CONTEXT>\n{combined_docs}\n</CONTEXT>\n\nQuery: {query}\n"
    if p:
        print(rag_prompt)

    payload = {"model": ollama_model_name, "prompt": rag_prompt, "stream": True}
    headers = {"Content-Type": "application/json"}

    response_text = rag_prompt if p else ""
    buffer = ""

    # The context manager releases the connection even if rendering fails mid-stream.
    with session.post(ollama_url_gen, headers=headers, data=json.dumps(payload), stream=True) as response:
        if response.status_code != 200:
            print("Request failed:", response.status_code, response.text)
            return response_text

        # Parse the stream line by line. Network chunks do not align with JSON
        # objects: one chunk may carry several lines or only part of one, and
        # decoding each raw chunk on its own silently drops that text.
        for raw_line in response.iter_lines(chunk_size=None):
            if not raw_line:
                continue
            try:
                data = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            buffer += data.get("response", "")

            # Re-render only every few characters to limit display churn.
            if len(buffer) > 10:
                response_text += buffer
                clear_output(wait=True)
                display(Markdown(response_text))
                buffer = ""

    # Display any remaining buffered content
    response_text += buffer
    clear_output(wait=True)
    display(Markdown(response_text))

    return response_text

In [5]:
# Single-turn query; p=True also prints the assembled prompt (instruction plus retrieved context).
_ = ask("Bangladesh Liberation War data?", p=True)

Instruction: If you do not find the answer in the CONTEXT, just say you don't know.

<CONTEXT>
Source: https://en.wikipedia.org/wiki/List_of_wars_by_death_toll
Table Headers: War, Death range, Date, Combatants, Location
War: South Sudanese Civil War | Death range: 0.38 million[143] | Date: 2013–2020 | Combatants: South Sudan vs. SPLM-IO, Nuer White Army, and SSDM | Location: South Sudan
War: Yemeni civil war | Death range: 0.15–0.37 million[144][145] | Date: 2014–present | Combatants: Multiple sides | Location: Yemen
War: Boko Haram insurgency | Death range: 0.03–0.35 million[146] | Date: 2009–present | Combatants: Multinational Joint Task Force vs. Boko Haram | Location: Nigeria
War: Franco-Dutch War | Death range: 0.34 million[147] | Date: 1672–1678 | Combatants: Kingdom of France vs. Dutch Republic | Location: Western Europe
War: Ottoman–Venetian wars | Death range: 0.34 million[148][149] | Date: 1415–1718 | Combatants: Ottoman Empire vs. Holy League | Location: Mediterranean Sea, Greece and Cyprus
War: Liberian Civil Wars and Sierra Leone Civil War | Death range: 0.3–0.32 million[150][151][152] | Date: 1989–2003 | Combatants: Liberian government, Revolutionary United Front vs. National Patriotic Front of Liberia, Liberians United for Reconciliation and Democracy, Movement for Democracy in Liberia, Sierra Leone | Location: West Africa
War: Goguryeo–Sui War | Death range: 0.3 million[153][154] | Date: 598–614 | Combatants: Sui Dynasty vs. Goguryeo | Location: Manchuria and Korean Peninsula
War: Carlist Wars | Death range: 0.3 million[155] | Date: 1833–1876 | Combatants: Carlists vs. Liberals and Republicans | Location: Iberian Peninsula
War: Iraqi conflict | Death range: 0.27–0.3 million[156] | Date: 2003–2017 | Combatants: Multiple sides | Location: Levant
War: Gulf War | Death range: 0.17–0.3 million[157][158] | Date: 1990–1991[e] | Combatants: Kuwait and the United States-led coalition vs. Iraq | Location: Kuwait and Iraq

Source: https://en.wikipedia.org/wiki/List_of_wars_by_death_toll
Table Headers: War, Death range, Date, Combatants, Location
War: Sri Lankan Civil War | Death range: 0.08–0.17 million[200][201] | Date: 1983[f]–2009 | Combatants: Sri Lankan government vs. Separatist Liberation Tigers of Tamil Eelam | Location: Sri Lanka
War: Russo-Japanese War | Death range: 0.12–0.16 million[203] | Date: 1904–1905 | Combatants: Empire of Japan vs. Russian Empire | Location: East Asia
War: Sudanese civil war (2023–present) | Death range: 0.15 million[204][205] | Date: 2023–present | Combatants: Sudan and allies vs. Rapid Support Forces and allies | Location: Sudan
War: Algerian Civil War | Death range: 0.15 million[206] | Date: 1992–2002 | Combatants: Multiple sides | Location: North Africa
War: Arab-Israeli conflict | Death range: 0.15 million[207][208][209][210] | Date: 1948[g]–present | Combatants: Israel vs. Arab League, Iran, Hezbollah, Hamas, and the Houthi movement | Location: Levant
War: Lebanese Civil War | Death range: 0.12–0.15 million[212][213][214] | Date: 1975–1990 | Combatants: Multiple sides | Location: Levant
War: Greek Civil War | Death range: 0.08–0.15 million[215][216] | Date: 1946–1949 | Combatants: Kingdom of Greece vs. Provisional Democratic Government | Location: Balkans and Peloponnese Peninsula
War: Yugoslav Wars | Death range: 0.13–0.14 million[217][218] | Date: 1991–2001 | Combatants: Separatist forces and NATO vs. Socialist Federal Republic of Yugoslavia, later Federal Republic of Yugoslavia | Location: Balkans
War: Irish Nine Year's War | Death range: 0.13 million[219] | Date: 1593–1603 | Combatants: Kingdom of England vs. Irish rebels | Location: Ireland
War: Chaco War | Death range: 0.08–0.13 million[220][221][222] | Date: 1932–1935 | Combatants: Paraguay vs. Bolivia | Location: Paraguay and Bolivia

Source: https://en.wikipedia.org/wiki/List_of_wars_by_death_toll
Table Headers: War, Death range, Date, Combatants, Location
War: World War II | Death range: 50–85 million[4][5][6] | Date: 1939–1945 | Combatants: Allied Powers vs. Axis Powers | Location: Global
War: Mongol invasions and conquests | Death range: 20–60 million[7][8][9][10] | Date: 1207–1405 | Combatants: Mongol Empire vs. various states in Eurasia | Location: Asia and Europe
War: Three Kingdoms | Death range: 34 million[10] | Date: 220–280 | Combatants: Multiple sides | Location: China
War: Taiping Rebellion | Death range: 20–30 million[11][12] | Date: 1850–1864 | Combatants: Qing Dynasty vs. Taiping Heavenly Kingdom | Location: China
War: World War I | Death range: 15–30 million[13][14] | Date: 1914–1918 | Combatants: Allied Powers vs. Central Powers | Location: Global
War: Manchu Conquest of China | Death range: 25 million[15][16] | Date: 1618–1683 | Combatants: Manchu vs. Ming Dynasty | Location: China
War: Conquests of Timur | Death range: 7–20 million[10] | Date: 1369–1405 | Combatants: Timurid Empire vs. various states in Asia | Location: Central Asia, West Asia, and South Asia
War: An Lushan rebellion | Death range: 13 million[17] | Date: 754–763 | Combatants: Tang Dynasty and Uyghur Khaganate vs. Yan Dynasty | Location: China
War: Thirty Years' War | Death range: 4–12 million[18] | Date: 1618–1648 | Combatants: Anti-Imperial Alliance vs. Imperial Alliance | Location: Europe
War: Spanish conquest of Mexico | Death range: 10.5 million[19] | Date: 1519–1530 | Combatants: Spanish Empire and allies vs. Aztec Empire and allies | Location: Mexico
War: Spanish conquest of the Inca Empire | Death range: 10 million[20] | Date: 1533–1572 | Combatants: Spanish Empire vs. Inca Empire | Location: South America
</CONTEXT>

Query: Bangladesh Liberation War data?
I don't know.

In [54]:
# Ollama chat endpoint that chat() posts to.
ollama_url_chat = "http://localhost:11434/api/chat"

# Conversation log sent as `messages` on every chat() call; re-running this cell resets it.
chat_history = []

def chat(query, k=2, p=False, stream=True):
    """Run one turn of a multi-turn RAG conversation against the Ollama chat endpoint.

    The query is embedded, the ``k`` nearest chunks are fetched from Qdrant and
    wrapped in a ``<context>`` block. That prompt is appended to the
    module-level ``chat_history`` as a user turn and the whole history is sent
    as ``messages``. The reply is rendered as Markdown while it arrives and is
    then recorded in ``chat_history`` as an assistant turn, so later turns see
    both sides of the conversation.

    Args:
        query: The user's message for this turn.
        k: Number of chunks to retrieve as context.
        p: When True, print the assembled prompt.
        stream: Passed as the ``stream`` flag both in the payload and to the
            HTTP request.

    Returns:
        The assistant's reply text. On a non-200 response the error is printed,
        the user turn just added is removed again, and ``""`` is returned.
    """
    global chat_history

    query_embedding = get_embedding(query)
    retrieved_docs = search_points_with_metadata(query_embedding, k)

    combined_docs = "\n\n".join([f"Source: {doc['url']}\n{doc['text']}" for doc in retrieved_docs])
    inst = "Instruction: If you do not find the answer in the context, just say you don't know."
    rag_prompt = f"{inst}\n\n<context>\n{combined_docs}\n</context>\n\nQuery: {query}\nAnswer:"
    if p:
        print(rag_prompt)

    chat_history.append({"role": "user", "content": rag_prompt})
    payload = {"model": ollama_model_name, "messages": chat_history, "stream": stream}

    response_text = ""
    buffer = ""

    # The context manager releases the connection even if rendering fails mid-stream.
    with session.post(ollama_url_chat, data=json.dumps(payload), stream=stream) as response:
        if response.status_code != 200:
            print("Request failed:", response.status_code, response.text)
            # Drop the unanswered user turn so a retry does not send two
            # consecutive user messages.
            chat_history.pop()
            return response_text

        # Parse the stream line by line. Network chunks do not align with JSON
        # objects, and decoding each raw chunk on its own silently drops text.
        for raw_line in response.iter_lines(chunk_size=None):
            if not raw_line:
                continue
            try:
                data = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            buffer += data.get("message", {}).get("content", "")

            # Re-render only every few characters to limit display churn.
            if len(buffer) > 10:
                response_text += buffer
                clear_output(wait=True)
                display(Markdown(response_text))
                buffer = ""

    # Display any remaining buffered content
    response_text += buffer
    clear_output(wait=True)
    display(Markdown(response_text))

    # Without this the history holds user turns only and the model never sees
    # its own earlier answers.
    chat_history.append({"role": "assistant", "content": response_text})

    return response_text

In [32]:
# Opening turn; the return value is discarded because chat() already renders the reply.
_ = chat("How many death in Colombian conflict?")

The Colombian conflict had approximately 0.45 million deaths.

In [33]:
# Follow-up turn: chat() sends the accumulated chat_history along with this new prompt.
_ = chat("Whats the reference of this information?")

I don't know the specific references for this particular information on the Colombian conflict page. The sources listed in the provided context are:

1. bercovitch, jacob jackson, richard (1997)
2. congressional quarterly
3. ictj.org (2009-01-01)
4. marley, david (1998)
5. abc-clio
6. farcau, bruce w. (1996-05-23)
7. bloomsbury academic
8. portal guarani (in european spanish)
9. tinker-salas, miguel (2015)
10. oxford university press
11. mwakikagile, godfrey (2014-04-21)
12. new africa press
13. arrian (1884)
14. cornell university library
15. london, hodder and stoughton
16. arrian (2018-04-10)
17. ozymandias press
18. staff, historynet (2007-09-17)
19. historynet
20. herre, bastian rodés-guirao, lucas roser, max (2024-03-20)
21. our world in data
22. herre, bastian roser, max (2024-07-15)
23. our world in data

However, I couldn't find a single reference that directly cites the death toll of 0.45 million for the Colombian conflict.

In [34]:
# Checks whether the model noticed the "Source: <url>" lines that chat() puts in each context block.
_ = chat("You have been provided the source URL as well in the context. No?")

Yes, you have been provided both the full URL and just the short version of it in the context.

In [35]:
# Clarifying turn; note that retrieval runs again on this sentence, so the attached context may differ.
_ = chat("Thats the source I was asking for.")

The source URL has indeed been provided in the context.

In [36]:
# Direct request for the source URL carried in the context blocks.
_ = chat("Tell me whats the source?")

The source is https://en.wikipedia.org/wiki/List_of_wars_by_death_toll, retrieved on September 30, 2024.