## Imports

In [2]:
import os
import json
import logging
import numpy as np
from tqdm import tqdm
from sentence_transformers import util, SentenceTransformer
from langchain_text_splitters import (
    TextSplitter,
    # RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)

import torch
from collections import Counter

  from tqdm.autonotebook import tqdm, trange


## Types

In [3]:
from typing import Literal
from pydantic import BaseModel, Field


###########
# SOURCES #
###########
class Source(BaseModel):
    """Metadata record describing one searchable source.

    Instances hash by ``id`` only, which lets them be stored in sets and used
    as dictionary / ``Counter`` keys.
    """

    id: str
    url: str
    name: str
    desc: str
    # Kind of source; "file" is used when the record does not say otherwise.
    type: Literal["moodle", "file", "web", "tg"] = Field(default="file")

    def __hash__(self) -> int:
        # Only the identifier takes part in the hash.
        return hash(self.id)


class MoodleSource(Source):
    """Source of type "moodle", carrying the id, URL and name of its course."""

    course_id: str
    course_url: str
    course_name: str
    type: Literal["moodle"] = Field(default="moodle")


class FileSource(Source):
    """Source whose ``type`` is fixed to "file"."""

    type: Literal["file"] = Field(default="file")


class WebSource(Source):
    """Source whose ``type`` is fixed to "web"."""

    type: Literal["web"] = Field(default="web")


class TelegramSource(Source):
    """Source whose ``type`` is fixed to "tg"."""

    type: Literal["tg"] = Field(default="tg")


#########
# UTILS #
#########
class Chunk(BaseModel):
    """A piece of text cut out of a single source.

    ``load_chunks_info`` numbers chunks with one running index across all
    sources, starting at 0.
    """

    index: int = Field(ge=0)  # chunk number; validated to be non-negative
    source_id: str  # ``Source.id`` of the source the text was taken from
    text: str  # text content of the chunk


##########
# SEARCH #
##########
class SearchQuery(BaseModel):
    """A search request consisting of the raw query text."""

    text: str  # query string; it is embedded with ``embed`` before searching


class SearchResult(BaseModel):
    """A source returned by ``search`` together with its score.

    Instances hash by the id of the wrapped source.
    """

    source: Source
    # Despite the name, ``search`` stores the dot-product similarity score here
    # (results scoring below the threshold are discarded there).
    distance: float

    def __hash__(self) -> int:
        # Delegate to the identifier of the wrapped source.
        return hash(self.source.id)

## Paths

In [4]:
from pathlib import Path

# metadata
DATA_PATH = Path("..") / "data"
META_FILE_PATH = DATA_PATH.joinpath("meta.json")

# text data
TEXTS_PATH = Path("..") / "texts"
PREPROCESSED_PATH = Path("..") / "preprocessed"

# everything related to validation
VALIDATION_PATH = Path("..") / "validation"

## Chunking

In [5]:
def load_sources_info(meta_file_path: Path) -> dict[str, Source]:
    """Read the metadata file and index its sources by id.

    Args:
        meta_file_path (Path): JSON file with the source records
            (presumably an array of objects - confirm against the data).

    Returns:
        dict[str, Source]: validated sources keyed by ``Source.id``. When an id
            occurs more than once, the last record wins.
    """
    with open(meta_file_path, "r", encoding="utf-8") as meta_file:
        records = json.load(meta_file)

    # Every record is serialized back to JSON so that it goes through strict JSON validation.
    validated = (Source.model_validate_json(json.dumps(record), strict=True) for record in records)
    return {source.id: source for source in validated}

In [11]:
def load_chunks_info(meta_file_path: Path, texts_path: Path, text_splitter: TextSplitter) -> dict[int, Chunk]:
    """Split the text of every source that has a text file into chunks.

    The text of a source is read from ``texts_path / (source.name + ".txt")``.
    Sources without such a file are skipped and their ids are logged.

    Args:
        meta_file_path (Path): JSON file with the source records.
        texts_path (Path): directory that holds the ``.txt`` files.
        text_splitter (TextSplitter): splitter whose ``split_text`` produces the chunk texts.

    Returns:
        dict[int, Chunk]: chunks keyed by their index; indices run from 0 upwards
            across all sources, in the order the sources appear in the metadata.
    """
    # Ids of sources without a text file end up in "missing.log".
    # NOTE(review): this configures the root logger; ``logging.basicConfig`` does
    # nothing when the root logger already has handlers - confirm this is acceptable.
    logging.basicConfig(
        level=logging.INFO,
        filename="missing.log",
        filemode="w",
        encoding="utf-8",
    )

    with open(meta_file_path, "r", encoding="utf-8") as meta_file:
        records = json.load(meta_file)

    # All sources currently listed in the metadata.
    sources: list[Source] = [
        Source.model_validate_json(json.dumps(record), strict=True) for record in records
    ]

    chunks_info: dict[int, Chunk] = {}
    for source in tqdm(sources, total=len(sources), desc="Split sources on chunks", unit="source"):
        source_text_path = texts_path / (source.name + ".txt")

        if os.path.exists(source_text_path):
            with open(source_text_path, "r", encoding="utf-8") as text_file:
                text = text_file.read()

            for piece in text_splitter.split_text(text):
                # Keys are contiguous, so the next free index equals the current size.
                position = len(chunks_info)
                chunks_info[position] = Chunk(index=position, source_id=source.id, text=piece)
        else:
            # No text available: record the source id and move on.
            logging.info(source.id)

    return chunks_info

In [13]:
def get_source_by_chunk(chunk_index: int, chunks_info: dict[int, Chunk], sources_info: dict[str, Source]) -> Source:
    """Resolve a chunk index to the source the chunk was taken from.

    Args:
        chunk_index (int): key of the chunk in ``chunks_info``.
        chunks_info (dict[int, Chunk]): chunks keyed by index.
        sources_info (dict[str, Source]): sources keyed by id.

    Raises:
        ValueError: the chunk, or the source it refers to, is not found.

    Returns:
        Source: source referenced by ``chunk.source_id``.
    """
    chunk = chunks_info.get(chunk_index)
    if not chunk:
        raise ValueError(f"Chunk {chunk_index} not found")

    source = sources_info.get(chunk.source_id)
    if not source:
        raise ValueError(f"Source {chunk.source_id} not found")

    return source

## Embeddings

In [63]:
def embed(
    texts: list[str],
    model: SentenceTransformer,
    batch_size: int = 64,
    show_progress_bar: bool | None = None,
) -> np.ndarray:
    """Encode texts into embeddings with the given model.

    Args:
        texts (list[str]): texts to encode.
        model (SentenceTransformer): model used for encoding.
        batch_size (int, optional): batch size passed to ``model.encode``. Defaults to 64.
        show_progress_bar (bool | None, optional): passed to ``model.encode``; ``None`` keeps
            the library's own default. Pass ``False`` to silence the "Batches" bar, e.g.
            when embedding a single query. Defaults to None.

    Returns:
        np.ndarray: embeddings as returned by ``model.encode``.
    """
    embeddings: np.ndarray = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=show_progress_bar,
    )
    return embeddings

## Search

In [15]:
def search(
    query_embedding: np.ndarray,
    embeddings: np.ndarray,
    sources_info: dict[str, Source],
    chunks_info: dict[int, Chunk],
    strategy: Literal["base", "majority_vote"] = "base",
    threshold: float = 0.4,
    k: int = 10,
    top_k: int = 50,
) -> list[SearchResult]:
    """Outputs results of semantic search with reranking strategy used among given sources.

    The ``top_k`` best matching chunks are retrieved by dot-product score, mapped to
    their sources, filtered by ``threshold`` and reduced to one result per source
    (the best scoring chunk of that source).

    Args:
        query_embedding (np.ndarray): vector representation of the query of size (1, {embedding_size}).
        embeddings (np.ndarray): vector representations of a corpus; row ``i`` must belong
            to the chunk stored under key ``i`` in ``chunks_info``.
        sources_info (dict[str, Source]): sources keyed by id.
        chunks_info (dict[int, Chunk]): chunks keyed by index.
        strategy (Literal["base", "majority_vote"], optional): reranking strategy among found sources.
            "base" orders sources by their best chunk score; "majority_vote" orders them by
            the number of their chunks that passed the threshold. Defaults to "base".
        threshold (float, optional): min value of similarity (inclusive) to be present on candidate.
            Defaults to 0.4.
        k (int, optional): final maximum number of sources. Defaults to 10.
        top_k (int, optional): number of best matching chunks retrieved before reranking.
            Defaults to 50.

    Raises:
        ValueError: signal that mentioned reranking strategy does not exist, or that a
            retrieved chunk / its source is missing from the given mappings.

    Returns:
        list[SearchResult]: final output of sources; ``distance`` holds the similarity
            score of the best chunk of each source (higher is more similar).
    """
    # Fail before doing any work when the strategy is unknown.
    if strategy not in ("base", "majority_vote"):
        raise ValueError("Strategy is not supported")

    hits = util.semantic_search(query_embedding, embeddings, top_k=top_k, score_function=util.dot_score)
    assert len(hits) == 1, "search expects exactly one query embedding"

    # semantic_search returns the hits of a query sorted by decreasing score.
    search_results: list[SearchResult] = [
        SearchResult(
            source=get_source_by_chunk(hit["corpus_id"], chunks_info, sources_info),
            distance=hit["score"],
        )
        for hit in hits[0]
    ]

    # The same inclusive threshold is applied for every strategy.
    relevant = [result for result in search_results if result.distance >= threshold]

    # First (= best scoring) result of every source, in order of decreasing score.
    best_by_source: dict[str, SearchResult] = {}
    for result in relevant:
        best_by_source.setdefault(result.source.id, result)

    if strategy == "base":
        ranked = list(best_by_source.values())
    else:
        # Count how many relevant chunks each source contributed; ties keep score order.
        votes = Counter(result.source.id for result in relevant)
        ranked = [best_by_source[source_id] for source_id, _ in votes.most_common()]

    return ranked[:k]

## Utils

In [22]:
def print_text_file(texts_path: Path, source_name: str):
    """Print the text file of a source, or a notice when it does not exist.

    Args:
        texts_path (Path): directory that holds the ``.txt`` files.
        source_name (str): source name; ".txt" is appended to build the file name.
    """
    target = texts_path / (source_name + ".txt")

    if os.path.exists(target):
        with open(target, "r", encoding="utf-8") as handle:
            contents = handle.read()
        print(contents)
    else:
        print(f"File {target} not found")

## Build the index (load sources, split into chunks, embed)

In [16]:
# NOTE: a RecursiveCharacterTextSplitter (Markdown + HTML separators, chunk_size=5000,
# chunk_overlap=500) is a possible alternative to the token-based splitter used below.

# sources info
sources_info = load_sources_info(META_FILE_PATH)  # type: ignore

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# model (alternative: "all-mpnet-base-v2")
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(MODEL_NAME, device=device)

# split on chunks
text_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=50, model_name=MODEL_NAME)
chunks_info = load_chunks_info(META_FILE_PATH, TEXTS_PATH, text_splitter)

# embed text chunks
texts = [chunk.text for chunk in chunks_info.values()]
embeddings = embed(texts, model)

# search() treats the row number of an embedding as the chunk index, so the rows
# must line up with the keys of chunks_info.
assert list(chunks_info) == list(range(len(chunks_info))), "chunk indices must be 0..n-1 in order"
assert embeddings.shape[0] == len(chunks_info), "expected one embedding per chunk"

embeddings.shape

cuda


Split sources on chunks: 100%|██████████| 949/949 [01:29<00:00, 10.64source/s] 


Batches:   0%|          | 0/1470 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


(47011, 384)

## Test

In [17]:
# Demo: query that names a specific course lecture.
query = SearchQuery(text="Burmykov Networks course lecture 11")
query_embedding = embed(texts=[query.text], model=model)

results: list[SearchResult] = search(
    query_embedding=query_embedding,
    embeddings=embeddings,
    sources_info=sources_info,
    chunks_info=chunks_info,
    strategy="base",
    threshold=0.4,
    k=10,
)
results

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[SearchResult(source=Source(id='module-78518.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=78518', name='module-78518.pdf', desc='Lab 4 (bul)', type='moodle'), distance=0.48109304904937744),
 SearchResult(source=Source(id='module-79154.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=79154', name='module-79154.pdf', desc='Graph Theory (main book)', type='moodle'), distance=0.4612114131450653),
 SearchResult(source=Source(id='https://innopolis.university/upload/iblock/0b3/92tmskeuzgrgqcs5r64xlb1g8g9i4tf7/Отчет_о_результатах_самообследования__2024__на_сайт.pdf', url='https://innopolis.university/upload/iblock/0b3/92tmskeuzgrgqcs5r64xlb1g8g9i4tf7/Отчет_о_результатах_самообследования__2024__на_сайт.pdf', name='Отчет_о_результатах_самообследования__2024__на_сайт.pdf', desc='Отчет о результатах самообследования Автономной некоммерческой организацией высшего образования «Университет Иннополис» за 2023 год', type='file'), distance=0.447313666343

In [18]:
# Demo: topical query.
query = SearchQuery(text="SOLID principles examples")
query_embedding = embed(texts=[query.text], model=model)

results: list[SearchResult] = search(
    query_embedding=query_embedding,
    embeddings=embeddings,
    sources_info=sources_info,
    chunks_info=chunks_info,
    strategy="base",
    threshold=0.4,
    k=10,
)
results

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[SearchResult(source=Source(id='module-92034.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=92034', name='module-92034.pdf', desc='Tutorial 08 - SOLID', type='moodle'), distance=0.46830806136131287),
 SearchResult(source=Source(id='module-92216.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=92216', name='module-92216.pdf', desc='2023 SSAD 14 Command, Chain or Resp, SOLID', type='moodle'), distance=0.4331497550010681)]

In [20]:
# Demo: nonsense query; the recorded run returned an empty list.
query = SearchQuery(text="aboba sus amogus")
query_embedding = embed(texts=[query.text], model=model)

results: list[SearchResult] = search(
    query_embedding=query_embedding,
    embeddings=embeddings,
    sources_info=sources_info,
    chunks_info=chunks_info,
    strategy="base",
    threshold=0.4,
    k=10,
)
results

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

[]

## Evaluation

In [67]:
class TestQuery(BaseModel):
    """One validation query as stored in ``queries.jsonl``."""

    text: str  # query text
    relevant: bool  # ``load_test_queries`` keeps only queries with this flag set
    sources: list[str] | None  # ids of the sources expected for the query, if any


def load_test_queries(validation_path: Path) -> list[TestQuery]:
    """Load the relevant validation queries from ``queries.jsonl``.

    Each non-blank line of the file is validated strictly as one ``TestQuery``.
    Queries whose ``relevant`` flag is false are dropped.

    Args:
        validation_path (Path): directory that contains ``queries.jsonl``.

    Returns:
        list[TestQuery]: relevant queries in file order.
    """
    queries: list[TestQuery] = []
    with open(validation_path / "queries.jsonl", "r", encoding="utf-8") as file:
        for line in file:
            # A blank line (e.g. left at the end of the file) is not a record.
            if not line.strip():
                continue

            query = TestQuery.model_validate_json(line, strict=True)
            if query.relevant:
                queries.append(query)

    return queries

In [68]:
# Load the validation queries (only those flagged as relevant are returned).
queries = load_test_queries(VALIDATION_PATH)

# Peek at one sample query.
queries[49]

TestQuery(text='verilog syntax', relevant=True, sources=['module-84616.pdf', 'module-84621.pdf', 'module-84787.pdf'])

In [64]:
from ranx import Qrels, Run, evaluate

qrels = Qrels(name="queries")
run = Run(name="queries")

# Embed all queries in one batch, then search them one at a time.
test_query_embedding = embed([query.text for query in queries], model)
for i, query in enumerate(queries):
    # A query without expected sources cannot be scored (``sources`` may be None).
    if not query.sources:
        continue

    results: list[SearchResult] = search(
        test_query_embedding[i].reshape(1, -1),
        embeddings,
        sources_info,
        chunks_info,
        strategy="base",
        threshold=0.4,
        k=10,
    )

    # Ground truth: earlier listed sources get higher relevance scores (n, n-1, ..., 1).
    qrels.add(q_id=query.text, doc_ids=query.sources, scores=list(range(len(query.sources), 0, -1)))

    # Ranking: scores only encode the result order. "test" is a placeholder doc id
    # used when the search returned nothing, so that the query is still in the run.
    result_ids = [result.source.id for result in results] or ["test"]
    run.add(q_id=query.text, doc_ids=result_ids, scores=list(range(len(result_ids), 0, -1)))

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [65]:
# Score the run against the ground truth with cut-off 10.
metrics = evaluate(
    qrels,
    run,
    ["hits@10", "hit_rate@10", "recall@10", "precision@10", "map@10"],
)

# One metric per line, names left-aligned to 12 characters.
for metric_name, metric_value in metrics.items():
    print(f"{metric_name:<12}: {metric_value:.2f}")

hits@10     : 0.88
hit_rate@10 : 0.62
recall@10   : 0.52
precision@10: 0.09
map@10      : 0.34
