In [140]:
import re
from selenium import webdriver
from selenium.webdriver.chrome.service import Service
from bs4 import BeautifulSoup
import pandas as pd
from transformers import AutoTokenizer
import os
import time
from io import StringIO
import logging
from tqdm.notebook import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_path = '/Users/hissain/git/github/models/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Define maximum token length per chunk
max_token_length = 480

def init_driver():
    options = webdriver.ChromeOptions()
    options.add_argument('--headless')
    options.add_argument('--no-sandbox')
    options.add_argument('--disable-dev-shm-usage')
    service = Service()
    return webdriver.Chrome(service=service, options=options)

def get_text_content(element):
    return ' '.join(element.stripped_strings)

def chunk_text(text, max_token_length):
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + max_token_length, len(tokens))
        chunk = tokenizer.decode(tokens[start:end])
        chunks.append(chunk)
        start = end
    return chunks

def merge_small_chunks(chunks, max_token_length):
    merged_chunks = []
    temp_chunk = ""
    
    for chunk in chunks:
        if len(tokenizer.encode(temp_chunk + " " + chunk)) <= max_token_length:
            temp_chunk += " " + chunk
        else:
            # Ensure no chunk exceeds max_token_length
            while len(tokenizer.encode(temp_chunk)) > max_token_length:
                # Split the temp_chunk if it's too long
                split_point = max_token_length - 1  # Choose safe split point
                merged_chunks.append(tokenizer.decode(tokenizer.encode(temp_chunk)[:split_point]))
                temp_chunk = tokenizer.decode(tokenizer.encode(temp_chunk)[split_point:])
                
            merged_chunks.append(temp_chunk.strip())
            temp_chunk = chunk
    
    if temp_chunk:
        merged_chunks.append(temp_chunk.strip())
    
    return merged_chunks

def chunk_table(df, max_token_length, header_info):
    table_chunks = []
    current_chunk = header_info + ' ||| '  # Distinct marker between header and rows
    
    for _, row in df.iterrows():
        row_text = ' | '.join([str(cell) for cell in row if pd.notna(cell)])
        combined_text = current_chunk + row_text + ' || '
        
        if len(tokenizer.encode(combined_text)) <= max_token_length:
            current_chunk += row_text + ' || '
        else:
            # Split the row if adding it would exceed max_token_length
            row_chunks = chunk_text(row_text, max_token_length)
            for sub_chunk in row_chunks:
                if len(tokenizer.encode(current_chunk)) + len(tokenizer.encode(sub_chunk)) <= max_token_length:
                    current_chunk += sub_chunk + ' || '
                else:
                    table_chunks.append(current_chunk.strip())
                    current_chunk = header_info + ' ||| ' + sub_chunk + ' || '
                    
    if current_chunk:
        table_chunks.append(current_chunk.strip())
    
    return table_chunks


def scrape_and_chunk_page(driver, url):
    driver.get(url)
    time.sleep(1)
    soup = BeautifulSoup(driver.page_source, 'html.parser')
    
    chunks = []
    current_url = url
    last_header = ""

    elements = soup.find_all(['h1', 'h2', 'h3', 'h4', 'p', 'table'])
    for element in elements:
        if element.name in ['h1', 'h2', 'h3', 'h4']:
            header_text = get_text_content(element)
            last_header = header_text  # Store this as context for following elements
            header_chunks = chunk_text(header_text, max_token_length)
            chunks.extend([(chunk, current_url) for chunk in header_chunks])
            
        elif element.name == 'p':
            paragraph_text = get_text_content(element)
            paragraph_chunks = chunk_text(paragraph_text, max_token_length)
            chunks.extend([(chunk, current_url) for chunk in paragraph_chunks])
            
        elif element.name == 'table':
            table_html = StringIO(str(element))
            df = pd.read_html(table_html)[0]
            
            # Drop empty rows and columns
            df.dropna(axis=0, how='all', inplace=True)
            df.dropna(axis=1, how='all', inplace=True)
            
            # Ensure column headers are strings
            df.columns = [str(col) for col in df.columns]
            header_info = last_header + ' | ' + ' | '.join(df.columns) if not df.columns.empty else last_header
            
            # Chunk the table content
            table_chunks = chunk_table(df, max_token_length, header_info)
            chunks.extend([(chunk, current_url) for chunk in table_chunks])

    # Merge small chunks where possible
    text_chunks = [chunk[0] for chunk in chunks]
    final_chunks = merge_small_chunks(text_chunks, max_token_length)
    
    # Re-associate URLs after merging
    return [(chunk, current_url) for chunk in final_chunks]

def scrape_and_chunk(urls):
    driver = init_driver()
    chunks = []
    for url in tqdm(urls, desc="Scraping pages"):
        chunks.extend(scrape_and_chunk_page(driver, url))
    return chunks

urls = [
    "https://en.wikipedia.org/wiki/List_of_wars_by_death_toll",
    "https://en.wikipedia.org/wiki/List_of_wars:_1990%E2%80%932002",
    "https://en.wikipedia.org/wiki/List_of_wars:_1945%E2%80%931989",
    "https://en.wikipedia.org/wiki/List_of_wars:_1900%E2%80%931944",
    "https://en.wikipedia.org/wiki/List_of_wars:_2003%E2%80%93present",
    "https://en.wikipedia.org/wiki/List_of_wars:_1800%E2%80%931899",
    "https://en.wikipedia.org/wiki/List_of_wars:_1500%E2%80%931799",
    "https://en.wikipedia.org/wiki/List_of_wars:_1000%E2%80%931499",
    "https://en.wikipedia.org/wiki/List_of_wars:_before_1000",
]

scraped_chunks = scrape_and_chunk(urls)

print(f"Total Chunks: {len(scraped_chunks)}")

for chunk, url in scraped_chunks[:3]:
    print(f"Chunk: {chunk}\nSource URL: {url}\n")


Scraping pages:   0%|          | 0/9 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (528 > 512). Running this sequence through the model will result in indexing errors


Total Chunks: 299
Chunk: contents list of wars by death toll List of wars by death toll | 0 ||| Part of a series on || War (outline) || showHistory || showMilitary || showBattlespace || showWeapons || showTactics || showOperational || showStrategy || showGrand strategy || showAdministrative || showOrganization || showPersonnel || showLogistics || showScience || showLaw || showTheory || showNon-warfare || showCulture || showRelated || hideLists Battles Military occupations Military terms Operations Sieges War crimes Wars Weapons Writers || vte || this list of wars by death toll includes all deaths that are either directly or indirectly caused by war. these numbers include the deaths of military personnel which are the direct results of a battle or other military wartime actions, as well as wartime / war - related deaths of civilians which are often results of war - induced epidemics, famines, genocide, etc. due to incomplete records, the destruction of evidence, differing methods of cou

In [133]:
from collections import Counter
import re
import numpy as np
from qdrant_client import QdrantClient, models
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer
from IPython.display import display, clear_output, Markdown
import requests
import json
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

session = requests.Session()
retry = Retry(total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504])
adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter)
session.headers.update({"Connection": "keep-alive", "Content-Type": "application/json"})

qdrant_url = "http://localhost:6333"
collection_name = "wiki_collection"
ollama_url_gen = "http://localhost:11434/api/generate"
ollama_model_name = "llama3.2:latest"

client = QdrantClient(url=qdrant_url)
embedding_model = SentenceTransformer(model_path)

TOP_K = 5
TOP_N = 3

def get_embeddings(texts):
    return embedding_model.encode(texts, batch_size=32, show_progress_bar=True)

def create_collection(dimension):
    client.delete_collection(collection_name=collection_name)
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
    )

def upsert_points_with_metadata(embeddings, chunks):
    points = [
        models.PointStruct(
            id=i,
            vector=embedding.tolist(),
            payload={"text": chunk, "url": url}
        ) for i, (embedding, (chunk, url)) in enumerate(zip(embeddings, chunks))
    ]
    client.upsert(collection_name=collection_name, points=points)

def store_in_qdrant_with_metadata(chunks):
    dimension = 384
    create_collection(dimension)
    chunk_texts = [chunk for chunk, _ in chunks]
    embeddings = get_embeddings(chunk_texts)
    upsert_points_with_metadata(embeddings, chunks)

def search_points_with_metadata(query_text, k=TOP_K):
    query_embedding = get_embeddings([query_text])[0]
    search_result = client.search(
        collection_name=collection_name,
        query_vector=query_embedding.tolist(),
        limit=k,
        with_payload=True
    )
    return [{"text": hit.payload["text"], "url": hit.payload["url"]} for hit in search_result]

def process_streamed_response(response, buffer_size=10):
    response_text, buffer = "", ""
    for chunk in response.iter_content(chunk_size=None):
        try:
            data = json.loads(chunk.decode('utf-8'))
            content = data.get("response", "")
            buffer += content

            if len(buffer) >= buffer_size:
                response_text += buffer
                clear_output(wait=True)
                display(Markdown(response_text))
                buffer = ""
                
        except json.JSONDecodeError:
            continue
            
    response_text += buffer
    clear_output(wait=True)
    display(Markdown(response_text))
    return response_text

def calculate_keyword_overlap_score(chunk, query_keywords):
    chunk_words = Counter(re.findall(r'\w+', chunk.lower()))
    overlap_count = sum(chunk_words.get(keyword, 0) for keyword in query_keywords)
    return overlap_count

def get_top_n_chunks_by_combined_score(query_text, retrieved_docs, n=TOP_N, semantic_weight=0.7, keyword_weight=0.3):
    query_keywords = set(re.findall(r'\w+', query_text.lower()))
    scored_chunks = []

    for doc in retrieved_docs:
        semantic_score = doc["score"]
        keyword_overlap_score = calculate_keyword_overlap_score(doc["text"], query_keywords)
        combined_score = (semantic_weight * semantic_score) + (keyword_weight * keyword_overlap_score)
        scored_chunks.append({"text": doc["text"], "url": doc["url"], "combined_score": combined_score})

    scored_chunks.sort(key=lambda n: n["combined_score"], reverse=True)
    print(f"Scores: {[n['combined_score'] for n in scored_chunks]}")
    return scored_chunks[:n]

def search_points_with_metadata(query_text, k=TOP_K, n=TOP_N, semantic_weight=0.8, keyword_weight=0.2):
    query_embedding = get_embeddings([query_text])[0]
    search_result = client.search(
        collection_name=collection_name,
        query_vector=query_embedding.tolist(),
        limit=k,
        with_payload=True
    )
    
    retrieved_docs = [{"text": hit.payload["text"], "url": hit.payload["url"], "score": hit.score} for hit in search_result]
    return get_top_n_chunks_by_combined_score(query_text, retrieved_docs, n=n, semantic_weight=semantic_weight, keyword_weight=keyword_weight)

def ask(query, k=5, n=3, verbose=False):
    retrieved_docs = search_points_with_metadata(query, k=k, n=n)
    combined_docs = "\n\n".join([f"Source: {doc['url']}\n\n{doc['text']}" for doc in retrieved_docs])
    inst = ("Instruction: If you do not find the answer within CONTEXT, please respond "
            "'Answer not found in context.' Do not speculate or create information beyond what is provided."
            "'Also respond naturally, dont start with phrase like, 'according to the context' or something similar.")
    rag_prompt = f"{inst}\n\n<CONTEXT>\n\n{combined_docs}\n\n</CONTEXT>\n\nQuery: {query}\n"

    if verbose:
        print(rag_prompt)
        
    payload = {"model": ollama_model_name, "prompt": rag_prompt, "stream": True}
    headers = {"Content-Type": "application/json"}

    response = session.post(ollama_url_gen, headers=headers, data=json.dumps(payload), stream=True)
    response_text = process_streamed_response(response) if response.status_code == 200 else "Request failed"
    
    return response_text

try:
    store_in_qdrant_with_metadata(scraped_chunks)
    print(f'Stored {len(scraped_chunks)} relevant chunks')
except Exception as e:
    print(f"Error storing in Qdrant: {e}")


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Stored 307 relevant chunks


In [134]:
_ = ask("Bangladesh Liberation War data?")

Bangladesh Liberation War: 0.3–3 million[49][50] | 1971 | India and Provisional Government of Bangladesh vs. Pakistan | Indian subcontinent

In [135]:
_ = ask("When was Federal War happened?")

Answer not found in context.

In [136]:
_ = ask("When did Quasi-War happend?")

The Quasi-War occurred from 1798 to 1800.

In [60]:
_ = ask("Where did Second Congo War happend?")

The Second Congo War took place in the Democratic Republic of the Congo.

In [95]:
_ = ask("What types of killings are excluded in the list?", k=5, n=1)

Mass killings, atrocities not explicitly classified as genocides, and genocides occurring outside of wartime, human sacrifices, ethnic cleansing operations, and acts of state terrorism or political repression during peacetime.

In [137]:
_ = ask("Show table format all data for Arab-Israeli conflict and Lebanese Civil War.", n=4)

Here is the requested data in table format:

| **Conflict** | **Started** | **Ended** | **Name of Conflict** | **Belligerents (Victorious party if applicable)** | **Belligerents (Defeated party if applicable)** |
| --- | --- | --- | --- | --- | --- |
| Second Intifada | 2000 | 2005 | Part of the Israeli–Palestinian conflict | Israel, Palestinian Authority, Fatah, PFLP, DFLP, Hamas, Islamic Jihad | Palestinian Authority, Fatah (al-Aqsa Martyrs' Brigades) |

Note: There is only one data point for the Arab-Israeli conflict and Lebanese Civil War in the provided text.

In [120]:
_ = ask("this 'list excludes mass killings and atrocities' of what types?")

According to the CONTEXT, the list excludes "mass killings and atrocities" that are not explicitly classified as genocides, such as:

* Genocides occurring outside of wartime
* Human sacrifices
* Ethnic cleansing operations
* Acts of state terrorism or political repression during peacetime.