In [2]:
import re
import pickle
from bs4 import BeautifulSoup
import pandas as pd
from transformers import AutoTokenizer
import os
from io import StringIO
import logging
from tqdm.notebook import tqdm

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_path = '/Users/hissain/git/github/models/all-MiniLM-L6-v2'
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Define maximum token length per chunk
max_token_length = 480

def get_text_content(element):
    return ' '.join(element.stripped_strings)

def chunk_text(text, max_token_length):
    tokens = tokenizer.encode(text, add_special_tokens=False)
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + max_token_length, len(tokens))
        chunk = tokenizer.decode(tokens[start:end])
        chunks.append(chunk)
        start = end
    return chunks

def merge_small_chunks(chunks, max_token_length):
    merged_chunks = []
    temp_chunk = ""
    
    for chunk in chunks:
        if len(tokenizer.encode(temp_chunk + " " + chunk)) <= max_token_length:
            temp_chunk += " " + chunk
        else:
            while len(tokenizer.encode(temp_chunk)) > max_token_length:
                split_point = max_token_length - 1  # Choose safe split point
                merged_chunks.append(tokenizer.decode(tokenizer.encode(temp_chunk)[:split_point]))
                temp_chunk = tokenizer.decode(tokenizer.encode(temp_chunk)[split_point:])
                
            merged_chunks.append(temp_chunk.strip())
            temp_chunk = chunk
    
    if temp_chunk:
        merged_chunks.append(temp_chunk.strip())
    
    return merged_chunks

def chunk_table(df, max_token_length, header_info):
    table_chunks = []
    current_chunk = header_info + ' ||| '
    
    for _, row in df.iterrows():
        row_text = ' | '.join([str(cell) for cell in row if pd.notna(cell)])
        combined_text = current_chunk + row_text + ' || '
        
        if len(tokenizer.encode(combined_text)) <= max_token_length:
            current_chunk += row_text + ' || '
        else:
            # Split the row if adding it would exceed max_token_length
            row_chunks = chunk_text(row_text, max_token_length)
            for sub_chunk in row_chunks:
                if len(tokenizer.encode(current_chunk)) + len(tokenizer.encode(sub_chunk)) <= max_token_length:
                    current_chunk += sub_chunk + ' || '
                else:
                    table_chunks.append(current_chunk.strip())
                    current_chunk = header_info + ' ||| ' + sub_chunk + ' || '
                    
    if current_chunk:
        table_chunks.append(current_chunk.strip())
    
    return table_chunks


def scrape_and_chunk_page(content):

    soup = BeautifulSoup(content[1], 'html.parser') # index-1 for html
    
    chunks = []
    current_url = content[0] #index-0 for url
    last_header = ""

    elements = soup.find_all(['h1', 'h2', 'h3', 'h4', 'p', 'table'])
    for element in elements:
        if element.name in ['h1', 'h2', 'h3', 'h4']:
            header_text = get_text_content(element)
            last_header = header_text
            header_chunks = chunk_text(header_text, max_token_length)
            chunks.extend([(chunk, current_url) for chunk in header_chunks])
            
        elif element.name == 'p':
            paragraph_text = get_text_content(element)
            paragraph_chunks = chunk_text(paragraph_text, max_token_length)
            chunks.extend([(chunk, current_url) for chunk in paragraph_chunks])
            
        elif element.name == 'table':
            table_html = StringIO(str(element))
            df = pd.read_html(table_html)[0]
            
            df.dropna(axis=0, how='all', inplace=True)
            df.dropna(axis=1, how='all', inplace=True)
            
            df.columns = [str(col) for col in df.columns]
            header_info = last_header + ' | ' + ' | '.join(df.columns) if not df.columns.empty else last_header
            
            table_chunks = chunk_table(df, max_token_length, header_info)
            chunks.extend([(chunk, current_url) for chunk in table_chunks])

    text_chunks = [chunk[0] for chunk in chunks]
    final_chunks = merge_small_chunks(text_chunks, max_token_length)
    
    return [(chunk, current_url) for chunk in final_chunks]

def scrape_and_chunk(html_contents):
    chunks = []
    for content in tqdm(html_contents, desc="Scraping pages"):
        chunks.extend(scrape_and_chunk_page(content))
    return chunks

with open("html_contents.pkl", "rb") as f:
    html_contents = pickle.load(f)

print(f"Loaded {len(html_contents)} URLs from pickle file")
scraped_chunks = scrape_and_chunk(html_contents)

print(f"Total Chunks: {len(scraped_chunks)}")

for chunk, url in scraped_chunks[:3]:
    print(f"Chunk: {chunk}\nSource URL: {url}\n")

Loaded 9 URLs from pickle file


Scraping pages:   0%|          | 0/9 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (528 > 512). Running this sequence through the model will result in indexing errors


Total Chunks: 299
Chunk: contents list of wars by death toll List of wars by death toll | 0 ||| Part of a series on || War (outline) || showHistory || showMilitary || showBattlespace || showWeapons || showTactics || showOperational || showStrategy || showGrand strategy || showAdministrative || showOrganization || showPersonnel || showLogistics || showScience || showLaw || showTheory || showNon-warfare || showCulture || showRelated || hideLists Battles Military occupations Military terms Operations Sieges War crimes Wars Weapons Writers || vte || this list of wars by death toll includes all deaths that are either directly or indirectly caused by war. these numbers include the deaths of military personnel which are the direct results of a battle or other military wartime actions, as well as wartime / war - related deaths of civilians which are often results of war - induced epidemics, famines, genocide, etc. due to incomplete records, the destruction of evidence, differing methods of cou

In [55]:
import spacy
from collections import Counter
import re
import numpy as np
from qdrant_client import QdrantClient, models
from tqdm.notebook import tqdm
from sentence_transformers import SentenceTransformer
from IPython.display import display, clear_output, Markdown
import requests
import json
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from rank_bm25 import BM25Okapi

# Load SpaCy's English model for Named Entity Recognition
nlp = spacy.load("en_core_web_sm")

session = requests.Session()
retry = Retry(total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504])
adapter = HTTPAdapter(max_retries=retry)
session.mount("http://", adapter)
session.headers.update({"Connection": "keep-alive", "Content-Type": "application/json"})

qdrant_url = "http://localhost:6333"
collection_name = "wiki_collection"
ollama_url_gen = "http://localhost:11434/api/generate"
ollama_model_name = "llama3.2:latest"

client = QdrantClient(url=qdrant_url)
embedding_model = SentenceTransformer(model_path)

TOP_K = 10
TOP_N = 4
SYM_W = 0.8
SYN_W = 0.2
NE_BOOST_FACTOR = 2.5
NE_FULL_BOOST_FACTOR = 2

def get_embeddings(texts):
    return embedding_model.encode(texts, batch_size=32, show_progress_bar=True)

def create_collection(dimension):
    client.delete_collection(collection_name=collection_name)
    client.create_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(size=dimension, distance=models.Distance.COSINE),
    )

def upsert_points_with_metadata(embeddings, chunks):
    points = [
        models.PointStruct(
            id=i,
            vector=embedding.tolist(),
            payload={"text": chunk, "url": url}
        ) for i, (embedding, (chunk, url)) in enumerate(zip(embeddings, chunks))
    ]
    client.upsert(collection_name=collection_name, points=points)

def store_in_qdrant_with_metadata(chunks):
    dimension = 384
    create_collection(dimension)
    chunk_texts = [chunk for chunk, _ in chunks]
    embeddings = get_embeddings(chunk_texts)
    upsert_points_with_metadata(embeddings, chunks)

def search_points_with_metadata(query_text, k=TOP_K):
    query_embedding = get_embeddings([query_text])[0]
    search_result = client.search(
        collection_name=collection_name,
        query_vector=query_embedding.tolist(),
        limit=k,
        with_payload=True
    )
    return [{"text": hit.payload["text"], "url": hit.payload["url"], "score": hit.score} for hit in search_result]

def init_bm25(corpus_texts):
    tokenized_corpus = [text.split() for text in corpus_texts]
    return BM25Okapi(tokenized_corpus)

def calculate_bm25_scores(bm25, query_text):
    tokenized_query = query_text.split()
    return bm25.get_scores(tokenized_query)

def extract_named_entities(text):
    doc = nlp(text)
    return [ent.text for ent in doc.ents]

def boost_ne_scores(query_text, docs, bm25_scores, boost_factor=NE_BOOST_FACTOR, full_match_boost=NE_FULL_BOOST_FACTOR):
    query_entities = extract_named_entities(query_text)
    print(f"Query Named Entities: {query_entities}")
    
    boosted_scores = []
    for idx, (doc, bm25_score) in enumerate(zip(docs, bm25_scores)):
        doc_entities = extract_named_entities(doc["text"])
        matching_ne_count = sum(1 for ne in query_entities if ne in doc_entities)
        full_match = all(ne in doc_entities for ne in query_entities)
        ne_boost = 1 + (boost_factor * matching_ne_count)
        if full_match:
            ne_boost *= full_match_boost
        boosted_scores.append(bm25_score * ne_boost)

    print(f"Top-4 Boosted scores: {boosted_scores[:4]}")          
    return boosted_scores

def calculate_boosted_scores(query_text, retrieved_docs, bm25):
    bm25_scores = calculate_bm25_scores(bm25, query_text)
    return boost_ne_scores(query_text, retrieved_docs, bm25_scores)

def get_top_n_chunks_by_combined_score(query_text, retrieved_docs, n=TOP_N, semantic_weight=SYM_W, keyword_weight=SYN_W):
    
    bm25 = init_bm25([doc["text"] for doc in retrieved_docs])
    boosted_keyword_scores = calculate_boosted_scores(query_text, retrieved_docs, bm25)
        
    scored_chunks = []
    
    for idx, doc in enumerate(retrieved_docs):
        semantic_score = doc["score"]
        keyword_score = boosted_keyword_scores[idx]
        combined_score = (semantic_weight * semantic_score) + (keyword_weight * keyword_score)
        scored_chunks.append({"text": doc["text"], "url": doc["url"], "combined_score": combined_score})

    scored_chunks.sort(key=lambda n: n["combined_score"], reverse=True)
    print(f"Top-4 Combined scores: {[s['combined_score'] for s in scored_chunks[:4]]}")
    return scored_chunks[:n]

def search_points_with_metadata(query_text, k=TOP_K, n=TOP_N, semantic_weight=SYM_W, keyword_weight=SYN_W):
    query_embedding = get_embeddings([query_text])[0]
    search_result = client.search(
        collection_name=collection_name,
        query_vector=query_embedding.tolist(),
        limit=k,
        with_payload=True
    )
    
    retrieved_docs = [{"text": hit.payload["text"], "url": hit.payload["url"], "score": hit.score} for hit in search_result]
    
    return get_top_n_chunks_by_combined_score(query_text, retrieved_docs, n=n, semantic_weight=semantic_weight, keyword_weight=keyword_weight)

def process_streamed_response(response, buffer_size=10):
    response_text, buffer = "", ""
    for chunk in response.iter_content(chunk_size=None):
        try:
            data = json.loads(chunk.decode('utf-8'))
            content = data.get("response", "")
            buffer += content

            if len(buffer) >= buffer_size:
                response_text += buffer
                clear_output(wait=True)
                display(Markdown(response_text))
                buffer = ""
                
        except json.JSONDecodeError:
            continue
            
    response_text += buffer
    clear_output(wait=True)
    display(Markdown(response_text))
    return response_text

def inspect(query, k=TOP_K, n=TOP_N):
    retrieved_docs = search_points_with_metadata(query, k=k, n=n)
    combined_docs = "\n\n".join([f"Source: {doc['url']}\n\n{doc['text']}" for doc in retrieved_docs])
    rag_prompt = f"Documents:\n\n<context>\n\n{combined_docs}\n\n</context>\n\nQuestion: {query}\n\nAnswer:\n"
    print(rag_prompt)

def ask(query, k=TOP_K, n=TOP_N, verbose=False):
    retrieved_docs = search_points_with_metadata(query, k=k, n=n)
    combined_docs = "\n\n".join([f"Source: {doc['url']}\n\n{doc['text']}" for doc in retrieved_docs])
    inst = ("Instruction: If you do not find the answer within the following context, please respond,"
            "'Answer not found in the context.' without speculation or general knowledge."
            "'Do not start with phrase like, 'according to the context', or anything similar.")
    rag_prompt = f"{inst}\n\n<context>\n\n{combined_docs}\n\n</context>\n\nQuestion: {query}\n\nAnswer:\n"

    if verbose:
        print(rag_prompt)
        
    payload = {"model": ollama_model_name, "prompt": rag_prompt, "stream": True}
    headers = {"Content-Type": "application/json"}

    response = session.post(ollama_url_gen, headers=headers, data=json.dumps(payload), stream=True)
    response_text = process_streamed_response(response) if response.status_code == 200 else "Request failed"
    
    return response_text

try:
    store_in_qdrant_with_metadata(scraped_chunks)
    print(f'Stored {len(scraped_chunks)} relevant chunks')
except Exception as e:
    print(f"Error storing in Qdrant: {e}")

Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Stored 299 relevant chunks


In [56]:
inspect("When did Bangladesh Liberation War happend?", n=2)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Query Named Entities: ['Bangladesh Liberation War']
Top-4 Boosted scores: [18.610997651532887, 2.7679512653215297, 0.9375919524608456, 0.8129783369707322]
Top-4 Combined scores: [4.198392866306578, 0.989499453064306, 0.7035941861932673, 0.6103937024921692]
Documents:

<context>

Source: https://en.wikipedia.org/wiki/List_of_wars_by_death_toll

List | War | Death range | Date | Combatants | Location ||| hundred years'war | 2. 3 – 3. 3 million [ 43 ] [ 44 ] | 1337 – 1453 | house of valois vs. house of plantagenet | western europe || Afghan conflict | 1.17–3 million[45][46][47] | 1978–present | Multiple sides; Afghan mujahideen, later Islamic Emirate of Afghanistan, United Tajik Opposition vs. Soviet Union, Democratic Republic of Afghanistan, Northern Alliance, Tajikistan, and the United States-led coalition | Afghanistan, Pakistan and Tajikistan || Delhi Conquest of North India | 0.5–3 million[48] | 1300–1310 | Delhi Sultanate vs. North Indian States | Indian subcontinent || Bangladesh L

In [57]:
inspect("How many died in Bangladesh Liberation War?", n=1)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Query Named Entities: ['Bangladesh Liberation War']
Top-4 Boosted scores: [10.022025190870822, 0.0, 2.3242734798638764, 0.40616663692422533]
Top-4 Combined scores: [2.4975188781741644, 0.8838925359727754, 0.7208951161093443, 0.5936040625268316]
Documents:

<context>

Source: https://en.wikipedia.org/wiki/List_of_wars_by_death_toll

List | War | Death range | Date | Combatants | Location ||| hundred years'war | 2. 3 – 3. 3 million [ 43 ] [ 44 ] | 1337 – 1453 | house of valois vs. house of plantagenet | western europe || Afghan conflict | 1.17–3 million[45][46][47] | 1978–present | Multiple sides; Afghan mujahideen, later Islamic Emirate of Afghanistan, United Tajik Opposition vs. Soviet Union, Democratic Republic of Afghanistan, Northern Alliance, Tajikistan, and the United States-led coalition | Afghanistan, Pakistan and Tajikistan || Delhi Conquest of North India | 0.5–3 million[48] | 1300–1310 | Delhi Sultanate vs. North Indian States | Indian subcontinent || Bangladesh Liberation Wa

In [58]:
inspect("When was Federal War happened?", n=1)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Query Named Entities: ['Federal War']
Top-4 Boosted scores: [0.857034638234949, 0.8209168045065249, 0.0, 0.5382903116934045]
Top-4 Combined scores: [0.8511307982715927, 0.5996745036469898, 0.558412536901305, 0.5325708909176011]
Documents:

<context>

Source: https://en.wikipedia.org/wiki/List_of_wars_by_death_toll

List | War | Death range | Date | Combatants | Location ||| irish nine year's war | 0. 13 million [ 219 ] | 1593 – 1603 | kingdom of england vs. irish rebels | ireland || Chaco War | 0.08–0.13 million[220][221][222] | 1932–1935 | Paraguay vs. Bolivia | Paraguay and Bolivia || Federal War | 0.1 million[223] | 1859–1863 | Federalists vs. Conservatives | Venezuela || Congo Crisis | 0.1 million[224] | 1960–1965 | Republic of the Congo, later Democratic Republic of the Congo, and allies vs. Free Republic of the Congo, South Kasai, Katanga, Kwilu rebels, Simba rebels, and allies | Republic of the Congo || Wars of Alexander the Great | 0.1 million[225][226][227] | 336 BCE–323 BCE |

In [59]:
inspect("When did Quasi-War happend?", n=1)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Query Named Entities: ['Quasi-War']
Top-4 Boosted scores: [0.0, 0.0, 0.0, 0.0]
Top-4 Combined scores: [2.5552559846372835, 0.39421968, 0.339146064, 0.33548426400000003]
Documents:

<context>

Source: https://en.wikipedia.org/wiki/List_of_wars:_1800%E2%80%931899

contents list of wars : 1800 – 1899 this article provides a list of wars occurring between 1800 and 1899. conflicts of this era include the napoleonic wars in europe, the american civil war in north america, the taiping rebellion in asia, the paraguayan war in south america, the zulu war in africa, and the australian frontier wars in oceania. 1800 – 1810 1800–1810 | ('Start', 'Start') | ('Finish', 'Finish') | ('Name of conflict', 'Name of conflict') | ('Belligerents', 'Victorious party (if applicable)') | ('Belligerents', 'Defeated party (if applicable)') ||| 1765 | 1865 | Temne War[1] | British Empire Susu Tribes | Kingdom of Koya || 1798 | 1800 | Quasi-War | United States | France || 1801 | 1805 | Tripolitan War | United Stat

In [61]:
inspect("Where did Second Congo War happend?", n=1)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Query Named Entities: ['Second Congo War']
Top-4 Boosted scores: [1.6547466984812753, 1.3513689639929223, 0.36892261359441214, 1.2945511262197553]
Top-4 Combined scores: [0.7562024596962551, 0.6739439527985844, 0.6698051461043588, 0.6522788012439511]
Documents:

<context>

Source: https://en.wikipedia.org/wiki/List_of_wars:_1990%E2%80%932002

List of wars: 1990–2002 | ('Started', 'Started') | ('Ended', 'Ended') | ('Name of Conflict', 'Name of Conflict') | ('Belligerents', 'Victorious party (if applicable)') | ('Belligerents', 'Defeated party (if applicable)') ||| 1997 | 1999 | republic of the congo civil war ( 1997 – 1999 ) | republic of the congo ( denis sassou nguesso government ) cobra militia rwandan hutu militia angola | republic of the congo ( pascal lissouba government ) cocoye militia ninja militia nsiloulou militia mamba militia || 1997 | 1997 | 1997 clashes in Cambodia | Hun Sen (CPP)  Vietnam | Norodom Ranariddh (FUNCINPEC)  Khmer Rouge || 1998 | 1998 | 1998 Monrovia clashes

In [66]:
inspect("What types of killings are excluded in the list?", n=3)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Query Named Entities: []
Top-4 Boosted scores: [13.641569564221392, 2.4662067289372995, 2.4505186928687825, 3.160018360121778]
Top-4 Combined scores: [2.9974295928442785, 1.2861029083949669, 1.2499110962812823, 1.1433991803930343]
Documents:

<context>

Source: https://en.wikipedia.org/wiki/List_of_wars_by_death_toll

contents list of wars by death toll List of wars by death toll | 0 ||| Part of a series on || War (outline) || showHistory || showMilitary || showBattlespace || showWeapons || showTactics || showOperational || showStrategy || showGrand strategy || showAdministrative || showOrganization || showPersonnel || showLogistics || showScience || showLaw || showTheory || showNon-warfare || showCulture || showRelated || hideLists Battles Military occupations Military terms Operations Sieges War crimes Wars Weapons Writers || vte || this list of wars by death toll includes all deaths that are either directly or indirectly caused by war. these numbers include the deaths of military pe

In [65]:
inspect("Who was the Defeated party in South Ossetia War in Georgian–Ossetian conflict.", n=2)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Query Named Entities: ['Defeated', 'South Ossetia War', 'Georgian']
Top-4 Boosted scores: [10.279079170302243, 6.339361948341971, 20.846444494312113, 7.731310983422011]
Top-4 Combined scores: [4.488982794862423, 2.7371478859240748, 2.3823097540604485, 1.913235340564431]
Documents:

<context>

Source: https://en.wikipedia.org/wiki/List_of_wars:_1990%E2%80%932002

List of wars: 1990–2002 | ('Started', 'Started') | ('Ended', 'Ended') | ('Name of Conflict', 'Name of Conflict') | ('Belligerents', 'Victorious party (if applicable)') | ('Belligerents', 'Defeated party (if applicable)') ||| 1991 | 1994 | djiboutian civil war | djibouti france | front for the restoration of unity and democracy || 1991 | 1993 | Georgian Civil War | Georgian State Council  Russia | Zviadists  National Guard of Georgia || 1991 | 2002 | Algerian Civil War | Algerian government | Armed Islamic Group (GIA) || 1992 | 1992 | 1992 Venezuelan coup d'état attempts | Venezuela | Revolutionary Bolivarian Movement-200 || 199