## Imports

In [1]:
import os
import json
import logging
import numpy as np
from tqdm import tqdm
from sentence_transformers import util, SentenceTransformer, CrossEncoder
from langchain_text_splitters import (
    TextSplitter,
    # RecursiveCharacterTextSplitter,
    SentenceTransformersTokenTextSplitter,
)

import torch
from collections import Counter
from ranx import Qrels, Run, evaluate

## Types

In [2]:
from typing import Literal
from pydantic import BaseModel, Field


###########
# SOURCES #
###########
class Source(BaseModel):
    id: str
    url: str
    name: str
    desc: str
    type: Literal["moodle", "file", "web", "tg"] = Field("file")

    def __hash__(self) -> int:
        return self.id.__hash__()


class MoodleSource(Source):
    course_id: str
    course_url: str
    course_name: str
    type: Literal["moodle"] = Field("moodle")


class FileSource(Source):
    type: Literal["file"] = Field("file")


class WebSource(Source):
    type: Literal["web"] = Field("web")


class TelegramSource(Source):
    type: Literal["tg"] = Field("tg")


#########
# UTILS #
#########
class Chunk(BaseModel):
    index: int = Field(ge=0)
    source_id: str
    text: str


##########
# SEARCH #
##########
class SearchQuery(BaseModel):
    text: str


class SearchResult(BaseModel):
    source: Source
    distance: float

    def __hash__(self) -> int:
        return self.source.id.__hash__()

## Paths

In [3]:
from pathlib import Path

# meta datas
DATA_PATH = Path("../data")
META_FILE_PATH = DATA_PATH / "meta.json"

# text data
TEXTS_PATH = Path("../texts")
PREPROCESSED_PATH = Path("../preprocessed")

# all related to validation
VALIDATION_PATH = Path("../validation")

## Chunk

In [4]:
def load_sources_info(meta_file_path: Path) -> dict[str, Source]:
    with open(meta_file_path, "r", encoding="utf-8") as meta_file:
        meta_data = json.load(meta_file)

    sources_info: dict[str, Source] = {}
    for data in meta_data:
        source: Source = Source.model_validate_json(json.dumps(data), strict=True)
        sources_info[source.id] = source

    return sources_info

In [5]:
def load_chunks_info(meta_file_path: Path, texts_path: Path, text_splitter: TextSplitter) -> dict[int, Chunk]:
    # log missing files
    logging.basicConfig(
        filename="missing.log",
        filemode="w",
        level=logging.INFO,
        encoding="utf-8",
    )

    with open(meta_file_path, "r", encoding="utf-8") as meta_file:
        meta_data = json.load(meta_file)

    # get current available sources list
    sources: list[Source] = []
    for data in meta_data:
        source: Source = Source.model_validate_json(json.dumps(data), strict=True)
        sources.append(source)

    index = 0
    chunks_info: dict[int, Chunk] = {}
    for source in tqdm(sources, total=len(sources), desc="Split sources on chunks", unit="source"):
        source_text_path = texts_path / (source.name + ".txt")

        # save not found files into logs
        if not os.path.exists(source_text_path):
            logging.info(source.id)
            continue

        # otherwise get their content
        with open(source_text_path, "r", encoding="utf-8") as text_file:
            text = text_file.read()

        # update info dict with current source's chunks
        for chunk_text in text_splitter.split_text(text):
            chunk = Chunk(index=index, source_id=source.id, text=chunk_text)
            chunks_info[index] = chunk
            index += 1

    return chunks_info

In [6]:
def get_source_by_chunk(chunk_index: int, chunks_info: dict[int, Chunk], sources_info: dict[str, Source]) -> Source:
    if not chunks_info.get(chunk_index):
        raise ValueError(f"Chunk {chunk_index} not found")

    chunk: Chunk = chunks_info[chunk_index]

    if not sources_info.get(chunk.source_id):
        raise ValueError(f"Source {chunk.source_id} not found")

    return sources_info[chunk.source_id]

## Embeddings

In [7]:
def embed(
    texts: list[str],
    model: SentenceTransformer,
    show_progress_bar: bool = False,
) -> np.ndarray:
    embeddings: np.ndarray = model.encode(texts, batch_size=64, show_progress_bar=show_progress_bar)
    return embeddings

## Search

In [17]:
def search(
    query_embedding: np.ndarray,
    embeddings: np.ndarray,
    sources_info: dict[str, Source],
    chunks_info: dict[int, Chunk],
    strategy: Literal["base", "majority_vote"] = "base",
    threshold: float = 0.4,
    top_k: int = 100,
    k: int = 10,
) -> list[SearchResult]:
    """Outputs results of semantic search with reranking strategy used among given sources.

    Args:
        query_embedding (np.ndarray): vector representation of the query of size (1, {embedding_size}).
        embeddings (np.ndarray): vector representations of a corpus.
        chunks (list[Chunk]): little pieces of sources.
        strategy (Literal["base", "majority_vote"], optional): reranking strategy among found sources. Defaults to "base".
        threshold (float, optional): min value of similarity to be present on candidate. Defaults to 0.4.
        top_k (int, optional): retreive closest number of chunks. Defaults to 100.
        k (int, optional): final maximum number of sources. Defaults to 10.

    Raises:
        ValueError: signal that mentioned reranking strategy does not exist.

    Returns:
        list[SearchResult]: final output of sources.
    """
    results = util.semantic_search(query_embedding, embeddings, top_k=top_k, score_function=util.dot_score)
    assert len(results) == 1

    chunk_indexes = []
    search_results: list[SearchResult] = []
    for result in results[0]:
        chunk_index = result["corpus_id"]
        chunk_indexes.append(chunk_index)

        source: Source = get_source_by_chunk(chunk_index, chunks_info, sources_info)
        search_result = SearchResult(text="", source=source, distance=result["score"])
        search_results.append(search_result)

    # by distance
    new_results: list[SearchResult] = []
    if strategy == "base":
        added_source_ids = set()
        for search_result in search_results:
            if search_result.distance < threshold:  # skip if lower than threshold
                continue

            if search_result.source.id not in added_source_ids:
                added_source_ids.add(search_result.source.id)
                new_results.append(search_result)

    # count appearance of chunk's belonging to a source
    elif strategy == "majority_vote":
        # apply majority vote
        counter = Counter(
            [search_result.source for search_result in search_results if search_result.distance > threshold]
        )
        most_common = counter.most_common(10)

        # filter and leave unique documents (a bit of crutch O(n^2))
        for source, _ in most_common:
            for result in search_results:
                if source.id == result.source.id:
                    new_results.append(result)
                    break

    else:
        raise ValueError("Strategy is not supported")

    print(chunk_indexes)
    return new_results[:k]

## Utils

In [11]:
def print_text_file(texts_path: Path, source_name: str):
    source_path = texts_path / (source_name + ".txt")
    if not os.path.exists(source_path):
        print(f"File {source_path} not found")
        return

    with open(source_path, "r", encoding="utf-8") as file:
        print(file.read())

## Code

In [12]:
sources_info = load_sources_info(META_FILE_PATH)  # type: ignore

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# model
# MODEL_NAME = 'sentence-transformers/all-mpnet-base-v2'  # SOTA
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
model = SentenceTransformer(MODEL_NAME, device=device)

cuda


In [12]:
# from langchain_text_splitters import TextSplitter, RecursiveCharacterTextSplitter, Language

# separators = RecursiveCharacterTextSplitter.get_separators_for_language(Language.MARKDOWN) + \
#              RecursiveCharacterTextSplitter.get_separators_for_language(Language.HTML)
# text_splitter = RecursiveCharacterTextSplitter(
#     chunk_size=5000,
#     chunk_overlap=500,
#     len
#     add_start_index=True,
#     separators=separators
# )

In [13]:
# split on chunks
text_splitter = SentenceTransformersTokenTextSplitter(chunk_overlap=25, model_name=MODEL_NAME)
chunks_info = load_chunks_info(META_FILE_PATH, PREPROCESSED_PATH, text_splitter)

Split sources on chunks: 100%|██████████| 949/949 [01:24<00:00, 11.19source/s]


In [14]:
# embed text chunks
texts = [chunk.text for chunk in chunks_info.values()]
embeddings = embed(texts, model, True)
embeddings.shape

Batches:   0%|          | 0/341 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


(21791, 384)

In [15]:
for text in texts[50::5000]:
    print(text)

##го содержания объявления плакаты листовки т. п. том числе на входных дверях комнат. 4. 2. 20. проживающим запрещается вносить хранить употреблять на территории комплекса алкогольные напитки кальяны наркотические вещества вносить хранить огнестрельное холодное оружие том числе топоры охотничьи ножи т. п. вз
is a directed ij 3cycle containing vj. a directed vi vjpath of length ij followed by the directed cycle together form a directed vi vjwalk of length as desired. there are two special cases. if d. ij then followed by a directed 2cycle through vj followed by a directed 3cycle through vj constitute a directed vi vjwalk of length the 2cycle exists since and if ij then followed by a directed lcycle through vj followed by a directed 3cycle through vj constitute such a walk a real matrix is called primitive if for soe k. corollary 10. 7 the adjacency matrix a of a tournament is primitive if and only if is diconnected and v : 4. proof if is not diconnected then there are vertices vi and vj

## Test

In [18]:
query = SearchQuery(text="Burmykov Networks course lecture 11")
query_embedding = embed([query.text], model)

results: list[SearchResult] = search(
    query_embedding, embeddings, sources_info, chunks_info, strategy="base", threshold=0.35, top_k=100, k=10
)
results

[5142, 5116, 4974, 5059, 4847, 5143, 5085, 5014, 5083, 11415, 6000, 11111, 5078, 6003, 11427, 5084, 11107, 5058, 11346, 6006, 5080, 6482, 11349, 11871, 11407, 5070, 11437, 11943, 5062, 5081, 11364, 5054, 7909, 5109, 11371, 17890, 11353, 5082, 11511, 11330, 11283, 1466, 11360, 11542, 11963, 5016, 11106, 6009, 5140, 5022, 4774, 11378, 17811, 8757, 11412, 11339, 4795, 5023, 6478, 6976, 8821, 1585, 6205, 4997, 8381, 6991, 5066, 11347, 8788, 11300, 11516, 6008, 5144, 6001, 8391, 4956, 4772, 5053, 4906, 11399, 4775, 6002, 5055, 6004, 11509, 5087, 11859, 11340, 8566, 6965, 8557, 5108, 7151, 11401, 8779, 11390, 8576, 11109, 5544, 424]


[SearchResult(source=Source(id='module-79154.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=79154', name='module-79154.pdf', desc='Graph Theory (main book)', type='moodle'), distance=0.41664445400238037),
 SearchResult(source=Source(id='module-109689.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=109689', name='module-109689.pdf', desc='Lecture Week 11 Part I (TCP Congestion Control)', type='moodle'), distance=0.38167300820350647),
 SearchResult(source=Source(id='module-92783.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=92783', name='module-92783.pdf', desc="Lab 13. Dijkstra's algorithm", type='moodle'), distance=0.38154342770576477),
 SearchResult(source=Source(id='module-108403.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=108403', name='module-108403.pdf', desc='Lecture Week 2 Part I (Network Characteristics)', type='moodle'), distance=0.38054201006889343),
 SearchResult(source=Sou

In [34]:
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

chunks_indexes = [
    5142,
    5116,
    4974,
    5059,
    4847,
    5143,
    5085,
    5014,
    5083,
    11415,
    6000,
    11111,
    5078,
    6003,
    11427,
    5084,
    11107,
    5058,
    11346,
    6006,
    5080,
    6482,
    11349,
    11871,
    11407,
    5070,
    11437,
    11943,
    5062,
    5081,
    11364,
    5054,
    7909,
    5109,
    11371,
    17890,
    11353,
    5082,
    11511,
    11330,
    11283,
    1466,
    11360,
    11542,
    11963,
    5016,
    11106,
    6009,
    5140,
    5022,
    4774,
    11378,
    17811,
    8757,
    11412,
    11339,
    4795,
    5023,
]

used_ids = set()
filtered_chunks = []
# print(chunks_info[chunks[1]])
for chunk_index in chunks_indexes[:12]:
    if chunks_info[chunk_index].source_id not in used_ids:
        used_ids.add(chunks_info[chunk_index].source_id)
        filtered_chunks.append(chunks_info[chunk_index])

for chunk in filtered_chunks:
    print(chunk.text)
    print(cross_encoder.predict([("Burmykov Networks course lecture 11", chunk.text)], show_progress_bar=False))
# scores = cross_encoder.predict(
#     [("Query", "Paragraph1"), ("Query", "Paragraph2"), ("Query", "Paragraph3")]
# )
# scores

and chartrand g. 1971. introductionto the theory of graphs allyn and bacon boston harary f. cd. 1967. a seminar on graph theory holt rinehart and winston new york ore o. 1962. theory of graphs american mathematical society provi dence r. i. konig d. 1950. theorie der endlichen und unendlichen graphen chelsea new york sachs h. 1970. einfuhrung in die theorie der endlichen graphen tebner verlagsgesellsch. aft leipzig harary f. 1969. graph theory addisonwesley reading mass. berge c. 1973. graphs and hypergraphs north holland amsterdam special topics biggs n. 1974. algebraic graph theory cambridge university press cambridge tutte w. t. 1966. connectivity in graphs university of toronto press toronto ore o. 1967. the fourcolor problem academic press new york ringel g. 1974. map color theorem springerverlag berlin appendix v : suggestions for further reading 255 moon j. w. 1968. topics on tournaments holt rinehart and winston new york ford l. r. jr. and fulkerson d. r. 1962. flows in network

In [20]:
query = SearchQuery(text="SOLID principles examples")
query_embedding = embed([query.text], model)

results: list[SearchResult] = search(
    query_embedding,
    embeddings,
    sources_info,
    chunks_info,
    strategy="base",  # base
    threshold=0.35,
    k=10,
)
results

[SearchResult(source=Source(id='module-92034.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=92034', name='module-92034.pdf', desc='Tutorial 08 - SOLID', type='moodle'), distance=0.6408908367156982),
 SearchResult(source=Source(id='module-81745.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=81745', name='module-81745.pdf', desc='Book 3 - The World Philosophy Made', type='moodle'), distance=0.3889666795730591),
 SearchResult(source=Source(id='module-84088.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=84088', name='module-84088.pdf', desc='Class 9A. Presentation', type='moodle'), distance=0.372954398393631),
 SearchResult(source=Source(id='module-92214.pdf', url='https://moodle.innopolis.university/mod/resource/view.php?id=92214', name='module-92214.pdf', desc='2023 SSAD 12 Bridge, Flyweight', type='moodle'), distance=0.3594950735569),
 SearchResult(source=Source(id='module-82959.pdf', url='https://moodle.innopoli

In [23]:
query = SearchQuery(text="aboba sus amogus")
query_embedding = embed([query.text], model)

results: list[SearchResult] = search(
    query_embedding, embeddings, sources_info, chunks_info, strategy="base", threshold=0.375, k=10
)
results

[]

## Evaluating

In [24]:
class TestQuery(BaseModel):
    text: str
    relevant: bool
    sources: list[str] | None


def load_test_queries(validation_path: Path) -> list[TestQuery]:
    queries: list[TestQuery] = []
    with open(validation_path / "queries.jsonl", "r", encoding="utf-8") as file:
        for line in file:
            query = TestQuery.model_validate_json(line, strict=True)
            if query.relevant:
                queries.append(query)

    return queries

In [44]:
def evaluation(
    # paths
    meta_file_path: Path | None = None,
    validation_path: Path | None = None,
    texts_path: Path | None = None,
    # sources
    sources_info: dict[str, Source] | None = None,
    # chunks
    chunks_info: dict[int, Chunk] | None = None,
    text_splitter: TextSplitter | None = None,
    # model and embeddings
    model: SentenceTransformer | None = None,
    embeddings: np.ndarray | None = None,
    # search
    strategy: Literal["base", "majority_vote"] = "base",
    threshold: float = 0.4,
    k: int = 10,
    # eval part
    metrics: list[str] | None = None,
):
    ####################
    # PARAMETER CHECKS #
    ####################
    # perform source info extraction if not present
    if not sources_info:
        if not os.path.exists(meta_file_path):
            raise ValueError("Unable to perform source info extraction")
        sources_info = load_sources_info(meta_file_path)

    # perform chunking if not present
    if not chunks_info:
        if (
            not os.path.exists(meta_file_path)
            or not os.path.exists(texts_path)
            or not isinstance(text_splitter, TextSplitter)
        ):
            raise ValueError("Unable to perform chunking")
        chunks_info = load_chunks_info(meta_file_path, texts_path, text_splitter)

    # check provided model
    if not isinstance(model, SentenceTransformer):
        raise ValueError("Given model is not SentenceTransformer class")

    # perform chunks' texts embedding if not present
    if embeddings is None:
        texts = [chunk.text for chunk in chunks_info.values()]
        embeddings = embed(texts, model, True)

    ###########
    # QUERIES #
    ###########
    if not os.path.exists(validation_path):
        raise ValueError("Unable to find validation path")

    queries = load_test_queries(validation_path)

    qrels = Qrels(name="queries")
    run = Run(name="queries")

    test_query_embedding = embed([query.text for query in queries], model)
    for i, query in enumerate(queries):
        # get results
        results: list[SearchResult] = search(
            test_query_embedding[i].reshape(1, -1),
            embeddings,
            sources_info,
            chunks_info,
            strategy=strategy,
            threshold=threshold,
            k=k,
        )

        # extract ids to match ground truth ones
        result_ids = [result.source.id for result in results]
        if len(result_ids) == 0:
            result_ids = ["value that definetely is not present"]

        # add qrels (ground truth) and run (retrieved) to compare
        qrels.add(q_id=query.text, doc_ids=query.sources, scores=[i for i in range(len(query.sources), 0, -1)])
        run.add(q_id=query.text, doc_ids=result_ids, scores=[i for i in range(len(result_ids), 0, -1)])

    return evaluate(qrels, run, metrics)

## Validation

In [45]:
queries = load_test_queries(VALIDATION_PATH)
queries[49]

TestQuery(text='verilog syntax', relevant=True, sources=['module-84616.pdf', 'module-84621.pdf', 'module-84787.pdf'])

In [46]:
# ["hits@10", "hit_rate@10", "recall@10", "precision@10", "map@10"]
treshold_param_grids = [
    {
        "threshold": 0.4,
    },
    {
        "threshold": 0.375,
    },
    {
        "threshold": 0.35,
    },
]


for param_grid in treshold_param_grids:
    metrics = evaluation(
        meta_file_path=META_FILE_PATH,
        validation_path=VALIDATION_PATH,
        texts_path=TEXTS_PATH,
        sources_info=sources_info,
        chunks_info=chunks_info,
        model=model,
        embeddings=embeddings,
        metrics=["hit_rate@10", "map@10", "mrr@10"],
        **param_grid,
    )

    print(param_grid)
    for k, v in metrics.items():
        print(f"{k:<12}: {v:.2f}")
    print()

{'threshold': 0.4}
hit_rate@10 : 0.62
map@10      : 0.33
mrr@10      : 0.41

{'threshold': 0.375}
hit_rate@10 : 0.64
map@10      : 0.34
mrr@10      : 0.42

{'threshold': 0.35}
hit_rate@10 : 0.68
map@10      : 0.35
mrr@10      : 0.44



In [48]:
all_minilm_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
all_mpnet_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

model_param_grids = [
    {
        "model": all_minilm_model,
        "embeddings": None,
    },
    {
        "model": all_mpnet_model,
        "embeddings": None,
    },
]

for param_grid in model_param_grids:
    metrics = evaluation(
        meta_file_path=META_FILE_PATH,
        validation_path=VALIDATION_PATH,
        texts_path=TEXTS_PATH,
        sources_info=sources_info,
        chunks_info=chunks_info,
        # model=model,
        # embeddings=embeddings,
        metrics=["hit_rate@10", "map@10", "mrr@10"],
        **param_grid,
    )

    print(param_grid)
    for k, v in metrics.items():
        print(f"{k:<12}: {v:.2f}")
    print()

Batches:   0%|          | 0/658 [00:00<?, ?it/s]

{'model': SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
), 'embeddings': None}
hit_rate@10 : 0.62
map@10      : 0.33
mrr@10      : 0.41



Batches:   0%|          | 0/658 [00:00<?, ?it/s]

{'model': SentenceTransformer(
  (0): Transformer({'max_seq_length': 384, 'do_lower_case': False}) with Transformer model: MPNetModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
), 'embeddings': None}
hit_rate@10 : 0.70
map@10      : 0.35
mrr@10      : 0.47

