# BEIR data + colbert search
* model: ModernBERT-based ColBERT model tuned for Korean
    * https://huggingface.co/sigridjineth/ModernBERT-Korean-ColBERT-preview-v1
    * uses pylate package implementation

In [1]:
import json
from pathlib import Path
import time
from typing import Any, Dict, List, Optional

import jsonlines
import pandas as pd
from tqdm import tqdm

from config import settings

In [2]:
import sys
import os

# Make the in-repo `core` package importable: it is expected under
# <parent of the current working directory>/src/psiking.
parent_dir = os.path.dirname(os.getcwd())
core_src_dir = os.path.join(parent_dir, "src/psiking")
# Guard the append so that re-running this cell does not stack duplicate entries.
if core_src_dir not in sys.path:
    sys.path.append(core_src_dir)

In [3]:
## Import Core Schemas
from core.base.schema import Document, TextNode

# 1. Read Data

In [4]:
## Implement Custom Readers
from core.reader.base import BaseReader

class QuoraDataReader(BaseReader):
    """Reader that converts BEIR-style corpus records into `Document` objects.

    A record is a dict with the keys ``_id``, ``title``, ``text`` and
    ``metadata``. `read` converts one record; `run` streams records from a
    JSON Lines file and collects the converted documents.
    """

    def __init__(self):
        pass

    def read(self, data: dict, extra_info: Optional[dict] = None,) -> Optional[Document]:
        """Convert one corpus record into a single-node `Document`.

        Args:
            data: Record with the keys ``_id``, ``title``, ``text``, ``metadata``.
            extra_info: Optional metadata attached to the text node and merged
                into the document metadata (its keys win on collision).

        Returns:
            The built `Document`, or ``None`` when the record has no text.

        Raises:
            KeyError: If a record with text lacks ``_id`` or ``title``.
        """
        shared_meta = extra_info or {}

        body = data.get('text', '')
        if body:
            text_node = TextNode(text=body, metadata=shared_meta)
            # Record identity first, then the caller-supplied metadata on top.
            doc_meta = {"source_id": data['_id'], "title": data['title']}
            doc_meta.update(shared_meta)
            return Document(nodes=[text_node], metadata=doc_meta)
        return None

    def run(self, file_path: str | Path,extra_info: Optional[dict] = None) -> List[Document]:
        """Read every record of a JSON Lines file and return the usable documents.

        Args:
            file_path: Path of the ``.jsonl`` corpus file.
            extra_info: Optional metadata applied to every document; each record
                receives its own shallow copy.

        Returns:
            Documents for all records for which `read` produced a truthy result,
            in file order.
        """
        base_meta = extra_info or {}
        with jsonlines.open(file_path) as rows:
            converted = (
                self.read(row, extra_info=dict(base_meta)) for row in rows
            )
            # Consume the generator while the file is still open.
            return [doc for doc in converted if doc]

In [5]:
# Parse the BEIR SciFact corpus file into Document objects, tagging each one
# with the name of the file it came from.
reader = QuoraDataReader()
document_path = os.path.join(settings.data_dir, "beir/scifact/corpus.jsonl")
documents = reader.run(
    document_path,
    extra_info={"source_file": "beir-scifact-corpus"},
)
print(len(documents))

5183


In [6]:
# Display the first two parsed documents as a quick sanity check.
documents[:2]

[Document(id_='5902c383-8c1e-493a-a0a8-d213858a1ccb', metadata={'source_id': '4983', 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.', 'source_file': 'beir-scifact-corpus'}, nodes=[TextNode(id_='fdf323c0-e800-483f-ae33-1ad94008ec70', metadata={'source_file': 'beir-scifact-corpus'}, text_type=<TextType.PLAIN: 'plain'>, label=<TextLabel.PLAIN: 'plain'>, resource=MediaResource(data=None, text='Alterations of the architecture of cerebral white matter in the developing human brain can affect cortical development and result in functional disabilities. A line scan diffusion-weighted magnetic resonance imaging (MRI) sequence with diffusion tensor analysis was applied to measure the apparent diffusion coefficient, to calculate relative anisotropy, and to delineate three-dimensional fiber architecture in cerebral white matter in preterm (n = 17) and full-term infants (n = 7). To assess effects of premat

# 2. Run Splitter

In [7]:
# 2. Run Splitter
from core.splitter.text.langchain_text_splitters import LangchainRecursiveCharacterTextSplitter

splitter = LangchainRecursiveCharacterTextSplitter(
    chunk_size = 1024,
    chunk_overlap = 128
)

# Build one "chunk" Document per source node. Text nodes are replaced by the
# pieces the splitter returns; any other node type is carried over unchanged.
chunks = []
for document in documents:
    # NOTE(review): this is the Document's own id_ (a UUID in the output above),
    # not the corpus `_id` that the reader stored in document.metadata['source_id'].
    # Confirm that downstream evaluation does not need the corpus id here.
    source_id = document.id_
    for node in document.nodes:
        # Run Splitter
        if isinstance(node, TextNode):
            split_nodes = splitter.run(node)
        else:
            split_nodes = [node]

        # Create New Document
        chunks.append(
            Document(
                nodes=split_nodes,
                metadata={
                    "source_id": source_id,
                    "source_file": document.metadata['source_file'],
                    "title": document.metadata['title'],
                }
            )
        )
print(len(chunks))

5183


# 3. Format (Prepare Embedding Input)

In [8]:
from core.formatter.document.simple import SimpleTextOnlyFormatter

# use default templates
# formatted_texts is indexed in parallel with `chunks` by the cells below.
formatter = SimpleTextOnlyFormatter()
formatted_texts = formatter.run(chunks)

def select_embedding_input_idxs(texts: List[str], min_length: int = 20) -> List[int]:
    """Return the positions of the texts that are long enough to embed.

    Args:
        texts: Candidate texts, in the order they will be indexed.
        min_length: Length threshold. A text is kept only when its
            whitespace-stripped length is strictly greater than this value.

    Returns:
        Indices into ``texts`` of the kept entries, in ascending order.
    """
    return [idx for idx, text in enumerate(texts) if len(text.strip()) > min_length]

# Keep only the chunks whose formatted text is longer than 20 characters;
# the surviving positions are reused later to pick the matching chunks.
embedding_input_idxs = select_embedding_input_idxs(formatted_texts, min_length=20)
print(len(embedding_input_idxs))

embedding_inputs = [formatted_texts[idx] for idx in embedding_input_idxs]

5183


# 4. Embedder

In [9]:
# Load Pylate Embedder
import torch
from pylate.models import ColBERT
from core.embedder.pylate import LocalPylateColBERTEmbedder

model_dir = os.path.join(settings.model_weight_dir, "late_interaction/ModernBERT-Korean-ColBERT-preview-v1")

# Prefer Apple's MPS backend (the original hard-coded choice), but fall back to
# CUDA or CPU so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# https://github.com/lightonai/pylate/blob/fe115ff8bd93351670d516859952804ced1198f7/pylate/models/colbert.py#L35
model = ColBERT(
    model_name_or_path=model_dir,
    embedding_size=128, # defaults to 128 if not set
    document_length=None, # don't set
    device=device,
    prompts={"query": "query: ", "passage": "passage: "} # input prefix text
)

embedder = LocalPylateColBERTEmbedder(
    model=model
)

PyLate model loaded successfully.


In [10]:
# Calculate Embeddings
embeddings = embedder.run(texts=embedding_inputs, batch_size=16, show_progress_bar=True)

# Report: number of passages, vectors in the first passage, size of one vector.
first_passage = embeddings[0]
print(len(embeddings), len(first_passage), len(first_passage[0]))

Encoding documents (bs=16):   0%|          | 0/324 [00:00<?, ?it/s]

5183 160 128


# 5. Add to VectorStore

## 5-1. Single 2D Vector Collection
* ColBERT returns a 2D matrix (one embedding vector per token) for each passage
* provide `MultiVectorConfig` with MaxSim Operation as comparator
* quantize to binary vectors

In [11]:
# 5-1. Add to VectorStore
from qdrant_client import QdrantClient
from core.storage.vectorstore.qdrant import QdrantLateInteractionVectorStore

# initialize client
# ":memory:" selects qdrant-client's local in-memory mode, so nothing is
# persisted once this client goes away.
client = QdrantClient(":memory:")
collection_name = "beir-scifact"

# Wrap the client in the project's late-interaction store for this collection.
vector_store = QdrantLateInteractionVectorStore(
    collection_name=collection_name,
    client=client
)

In [12]:
from qdrant_client.http import models

# Size of a single vector inside a passage's embedding matrix
# (the last number printed by the embedding cell).
embedding_dim = len(embeddings[0][0])

# Create the single-field collection: each point holds one multivector
# (a matrix of embedding_dim-sized vectors) compared with MaxSim.
vector_store.create_collection(
    on_disk_payload=True,  # store the payload on disk
    vectors_config = models.VectorParams(
        size=embedding_dim,
        distance=models.Distance.COSINE,
        hnsw_config=models.HnswConfigDiff(
            m=0 # m = number of edges per node in the index graph; 0 switches HNSW off
        ),
        multivector_config=models.MultiVectorConfig(
            comparator=models.MultiVectorComparator.MAX_SIM # similarity metric between multivectors (matrices)
        ),
        # Binary quantization of the stored vectors; always_ram=False does not
        # force the quantized copy to stay in RAM.
        quantization_config=models.BinaryQuantization(
            binary=models.BinaryQuantizationConfig(
                always_ram=False
            ),
        ),
        on_disk=True,
    )
)

In [13]:
# Insert only the chunks that were actually embedded: `embeddings` was computed
# from the texts at embedding_input_idxs, so selecting the same positions from
# `chunks` keeps documents and embeddings aligned one-to-one.
# NOTE(review): metadata_keys presumably names the Document metadata fields copied
# into each point's payload (matches the payload printed below) — confirm in the store.
vector_store.add(
    documents=[chunks[x] for x in embedding_input_idxs],
    embeddings=embeddings,
    metadata_keys=["source_file", "source_id", "title"]
)

In [14]:
# check collection
# Fetch the collection description from Qdrant and dump it as indented JSON.
# NOTE(review): this reaches into the store's private `_client` attribute;
# presumably it is the same `client` object passed to the store — confirm.
collection_info = vector_store._client.get_collection(
    collection_name=vector_store.collection_name
)
print(collection_info.model_dump_json(indent=4))

{
    "status": "green",
    "optimizer_status": "ok",
    "vectors_count": null,
    "indexed_vectors_count": 0,
    "points_count": 5183,
    "segments_count": 1,
    "config": {
        "params": {
            "vectors": {
                "size": 128,
                "distance": "Cosine",
                "hnsw_config": {
                    "m": 0,
                    "ef_construct": null,
                    "full_scan_threshold": null,
                    "max_indexing_threads": null,
                    "on_disk": null,
                    "payload_m": null
                },
                "quantization_config": {
                    "binary": {
                        "always_ram": false
                    }
                },
                "on_disk": true,
                "datatype": null,
                "multivector_config": {
                    "comparator": "max_sim"
                }
            },
            "shard_number": null,
            "sharding_method": null

In [15]:
# check point
# Only the chunks selected by embedding_input_idxs were added to the store, so
# look up the first *indexed* chunk rather than chunks[0], which may have been
# filtered out for being too short (retrieve would then return no point).
points = vector_store._client.retrieve(
    collection_name=vector_store.collection_name,
    ids=[chunks[embedding_input_idxs[0]].id_],
    with_vectors=True
)

In [16]:
# Show the id, payload and number of stored vectors of the retrieved point.
first_point = points[0]
print(first_point.id)
print(first_point.payload)
print(len(first_point.vector))

27dd4060-dd14-4f64-96a5-f281b4e3061a
{'source_id': '5902c383-8c1e-493a-a0a8-d213858a1ccb', 'source_file': 'beir-scifact-corpus', 'title': 'Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging.'}
160


## 5-2. Multi-Vector Collection
* calculate row/column means for each passage embedding matrix

Example
* https://qdrant.tech/documentation/advanced-tutorials/pdf-retrieval-at-scale/

In [None]:
# 5-2. Add to VectorStore (pooled multi-vector collection)
from qdrant_client import QdrantClient
from core.storage.vectorstore.qdrant import QdrantLateInteractionPooledVectorStore

# initialize client
# NOTE: this rebinds `client` and `vector_store`; the names no longer refer to
# the "beir-scifact" objects created in section 5-1.
client = QdrantClient(":memory:")
collection_name = "beir-scifact-pooled"

vector_store = QdrantLateInteractionPooledVectorStore(
    collection_name=collection_name,
    client=client
)

In [None]:
def _mean_pooled_vector_params():
    """Return fresh VectorParams for one mean-pooled multivector field."""
    return models.VectorParams(
        size=embedding_dim,
        distance=models.Distance.COSINE,
        multivector_config=models.MultiVectorConfig(
            comparator=models.MultiVectorComparator.MAX_SIM
        ),
    )


# Original 2D Vector: full embedding matrix, HNSW switched off (m=0),
# binary-quantized, stored on disk.
dense_config = models.VectorParams(
    size=embedding_dim,
    distance=models.Distance.COSINE,
    on_disk=True,
    hnsw_config=models.HnswConfigDiff(m=0),
    multivector_config=models.MultiVectorConfig(
        comparator=models.MultiVectorComparator.MAX_SIM
    ),
    quantization_config=models.BinaryQuantization(
        binary=models.BinaryQuantizationConfig(always_ram=False),
    ),
)

# Mean pooling configs: row means and column means share the same settings
# but are kept as two separate objects.
row_means_config = _mean_pooled_vector_params()
col_means_config = _mean_pooled_vector_params()


# Named vector fields of the pooled collection.
vectors_config = dict(
    dense=dense_config,
    dense_col_means=col_means_config,
    dense_row_means=row_means_config,
)