In [4]:
import cg_rag

# Baseline retrievers module, aliased as R for the rest of the notebook.
import cg_rag.baselines.retrievers as R
# List the module's names to see which retrievers and data types it exposes.
dir(R)

['BM25Okapi',
 'BM25Retriever',
 'DenseRetriever',
 'Embedder',
 'List',
 'Optional',
 'RetrievalResult',
 'Segment',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'np']

In [5]:
import inspect
# Print the DenseRetriever source to see how it scores and ranks segments
# before using it as the baseline.
print(inspect.getsource(R.DenseRetriever))

class DenseRetriever:
    """Simple dense retrieval baseline using cosine similarity."""

    def __init__(self, segments: List[Segment], embedder: Embedder, top_k: int = 3):
        self.segments = segments
        self.embedder = embedder
        self.top_k = top_k

        self.U = np.array([s.vector for s in segments])
        norms = np.linalg.norm(self.U, axis=1, keepdims=True)
        self.U_norm = self.U / np.maximum(norms, 1e-10)

    def _get_scores(self, query: str) -> np.ndarray:
        query_emb = self.embedder.embed([query], show_progress=False)[0]
        query_norm = np.linalg.norm(query_emb)
        query_normalized = query_emb / query_norm if query_norm > 1e-10 else query_emb
        return np.dot(self.U_norm, query_normalized)

    def retrieve(self, query: str, verbose: bool = False) -> RetrievalResult:
        scores = self._get_scores(query)
        top_indices = np.argsort(scores)[::-1][:self.top_k]

        return RetrievalResult(
            segments=[self.seg

In [6]:
# Import the BEIR helpers, then download and load the SciFact dataset.
from beir import util
from beir.datasets.data_loader import GenericDataLoader
# Downloads the SciFact archive into ./datasets and returns the path that is
# handed to the loader below.
data_path = util.download_and_unzip( "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip","./datasets")

# Load the "test" split as three objects: corpus, queries and qrels.
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
# Display the dataset size as (number of documents, number of queries).
len(corpus), len(queries)

  0%|          | 0/5183 [00:00<?, ?it/s]

(5183, 300)

In [7]:
# Embedding module, aliased as E; list its names to find the embedder class.
import cg_rag.models.embedding as E
dir(E)

['Embedder',
 'List',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 'np',
 'torch']

In [8]:
# This is the embedder I will use; print its source to see which model it
# selects by default.
print(inspect.getsource(E.Embedder))

class Embedder:
    """Wrapper for sentence-transformers embedding models."""

    def __init__(self, model_name: str = None):
        from sentence_transformers import SentenceTransformer

        # 1. Hardware Detection & Model Selection
        if torch.cuda.is_available():
            self.device = "cuda"
            # High-performance, large-context embedding model
            #default_model = 'Alibaba-NLP/gte-Qwen2-7B-instruct'
            default_model = 'all-MiniLM-L6-v2' # use this for debugging, is super fast
            print(f"GPU detected. Using high-fidelity embedder: {default_model}")
        else:
            self.device = "cpu"
            # Lightweight, highly efficient model for CPU
            #default_model = 'Alibaba-NLP/gte-Qwen2-1.5B-instruct'
            default_model = 'all-MiniLM-L6-v2' # use this for debugging, is super fast
            print(f"No GPU found. Using lightweight CPU embedder: {default_model}")

        self.model_name = model_name or default_mo

In [9]:
# Work out how to construct a Segment: print its constructor signature and
# its field annotations.
import numpy as np
from beir import util
from beir.datasets.data_loader import GenericDataLoader
import cg_rag.baselines.retrievers as R
import cg_rag.models.embedding as E
import cg_rag.structures as S  # contains Segment, RetrievalResult, etc.
import inspect
print("Segment signature:", inspect.signature(S.Segment))
# getattr with a None default prints None instead of raising if the class
# carries no __annotations__.
print("Segment annotations:", getattr(S.Segment, "__annotations__", None))

Segment signature: (text: str, vector: numpy.ndarray, start_idx: int, end_idx: int, sentences: List[str] = <factory>, internal_cost: float = 0.0) -> None
Segment annotations: {'text': <class 'str'>, 'vector': <class 'numpy.ndarray'>, 'start_idx': <class 'int'>, 'end_idx': <class 'int'>, 'sentences': typing.List[str], 'internal_cost': <class 'float'>}


In [10]:
# Download the SciFact archive into ./datasets; the returned path is passed
# to the loader below.
data_path = util.download_and_unzip(
    "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip",
    "./datasets"
)
# Load the "test" split: corpus (documents), queries, and qrels (relevance
# judgements, looked up per query id later in the notebook).
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
print("corpus:", len(corpus), "queries:", len(queries))

  0%|          | 0/5183 [00:00<?, ?it/s]

corpus: 5183 queries: 300


In [11]:
# Flatten the corpus into two parallel lists so that doc_texts[i] is the text
# of doc_ids[i]. Both are derived from the same dict, so their order matches.
doc_ids = [*corpus.keys()]
doc_texts = [
    # "title text", with strip() removing the stray space when either part is empty.
    " ".join((fields.get("title", ""), fields.get("text", ""))).strip()
    for fields in corpus.values()
]
# Peek at the first document: its id and the first 200 characters of its text.
print("Example doc:", doc_ids[0], doc_texts[0][:200])

Example doc: 4983 Microstructural development of human newborn cerebral white matter assessed in vivo by diffusion tensor magnetic resonance imaging. Alterations of the architecture of cerebral white matter in the deve


In [12]:
# Embed every document once, up front.
# No model_name is passed, so Embedder falls back to its own default model.
embedder = E.Embedder() 
# NOTE(review): assumes row i of doc_vectors is the embedding of doc_texts[i];
# the segment-building step zips the two together and relies on that order.
doc_vectors = embedder.embed(doc_texts, batch_size=64, show_progress=True)
print("doc_vectors shape:", doc_vectors.shape)

No GPU found. Using lightweight CPU embedder: all-MiniLM-L6-v2
Loading embedding model: None
Embedder ready (dim=384)


Batches:   0%|          | 0/81 [00:00<?, ?it/s]

doc_vectors shape: (5183, 384)


In [13]:
# Wrap every document in a Segment (one segment per whole document) so the
# retriever can index it.

# The three inputs are parallel sequences. zip() would silently drop the tail
# if they ever disagreed in length, leaving documents un-indexed, so fail loudly.
assert len(doc_ids) == len(doc_texts) == len(doc_vectors), (
    "doc_ids, doc_texts and doc_vectors must be aligned one-to-one: "
    f"{len(doc_ids)} / {len(doc_texts)} / {len(doc_vectors)}"
)

segments = [
    S.Segment(
        text=text,
        vector=vec,
        start_idx=i,           # position of this document in doc_ids
        end_idx=i + 1,
        sentences=[],          # optional
        internal_cost=0.0,     # optional
    )
    for i, (text, vec) in enumerate(zip(doc_texts, doc_vectors))
]
segment_doc_ids = list(doc_ids)  # same length as segments: segment_doc_ids[i] <-> segments[i]

# Map python object identity -> index, so we can recover doc_id later.
# NOTE(review): this only works while the retriever hands back these exact
# Segment objects (not copies) and `segments` stays alive; start_idx holds the
# same index and would be a sturdier key.
seg_id_to_idx = {id(seg): i for i, seg in enumerate(segments)}

print("segments:", len(segments))
print("example:", segments[0].start_idx, segments[0].end_idx, segment_doc_ids[0])

segments: 5183
example: 0 1 4983


In [14]:
# Dense cosine-similarity baseline over the document-level segments.
# top_k=10 caps each retrieval at ten results, so metrics computed from this
# retriever are at most hit@10.
dense = R.DenseRetriever(segments=segments, embedder=embedder, top_k=10)

In [15]:
def is_hit_at_k(retrieved_doc_ids, qrel_dict):
    """Return True if at least one retrieved document is relevant to the query.

    Args:
        retrieved_doc_ids: Iterable of document ids, already cut to the rank
            the caller cares about (this function does no slicing itself).
        qrel_dict: Mapping of doc_id -> numeric relevance judgement for a
            single query.

    Returns:
        bool: True when any retrieved id has a positive relevance judgement.
        An empty retrieval list or an empty qrel mapping yields False.
    """
    # Only positively judged documents count as gold: a document listed with
    # relevance 0 was judged and found NOT relevant, so it must not score a hit.
    gold = {doc_id for doc_id, relevance in qrel_dict.items() if relevance > 0}
    return any(doc_id in gold for doc_id in retrieved_doc_ids)

# Run the dense baseline over every query and record, per query, whether any
# gold document shows up among the retrieved segments.
hits = []
dense_fail_qids = []

for qid, qtext in queries.items():
    result = dense.retrieve(qtext)  # RetrievalResult

    # Translate each returned segment back to its doc_id through the identity map.
    retrieved_doc_ids = [
        segment_doc_ids[seg_id_to_idx[id(seg)]] for seg in result.segments
    ]

    hit = is_hit_at_k(retrieved_doc_ids, qrels.get(qid, {}))
    hits.append(hit)
    if hit:
        continue
    # Keep the misses for the failure analysis further down.
    dense_fail_qids.append(qid)

n_hits, n_queries = sum(hits), len(hits)
print(f"Dense hit@10: {n_hits}/{n_queries} = {n_hits/n_queries:.3f}")
print("Dense failures:", len(dense_fail_qids))
dense_fail_qids[:20]

Dense hit@10: 238/300 = 0.793
Dense failures: 62


['13',
 '48',
 '70',
 '128',
 '132',
 '198',
 '239',
 '294',
 '303',
 '312',
 '314',
 '384',
 '415',
 '421',
 '431',
 '437',
 '475',
 '502',
 '508',
 '517']

In [16]:
import pandas as pd

# Blank labelling sheet: one row per failed query, with empty "label" and
# "notes" columns to be filled in by hand.
rows = [
    {"qid": qid, "query": queries[qid], "label": "", "notes": ""}
    for qid in dense_fail_qids
]

dense_fail_df = pd.DataFrame(rows)
dense_fail_df.head(10)

Unnamed: 0,qid,query,label,notes
0,13,5% of perinatal mortality is due to low birth ...,,
1,48,"A total of 1,000 people in the UK are asymptom...",,
2,70,Activation of PPM1D suppresses p53 function.,,
3,128,Arterioles have a larger lumen diameter than v...,,
4,132,Aspirin inhibits the production of PGE2.,,
5,198,CCL19 is absent within dLNs.,,
6,239,Cellular aging closely links to an older appea...,,
7,294,Crossover hot spots are not found within gene ...,,
8,303,DMRT1 is a sex-determining gene that is epigen...,,
9,312,De novo assembly of sequence data has more spe...,,


In [17]:
def inspect_dense_fail(qid, k=5):
    """Print one query next to its gold doc ids and its top dense results.

    Reads the notebook-level `queries`, `qrels`, `corpus`, `dense`,
    `segment_doc_ids` and `seg_id_to_idx`. Prints to stdout and returns None.

    Args:
        qid: Query id, used as the key into `queries` and `qrels`.
        k: How many of the retrieved segments to show (default 5).
    """
    banner = "=" * 70
    qtext = queries[qid]
    gold_ids = list(qrels.get(qid, {}))

    # Retrieval happens before any printing.
    top_segments = dense.retrieve(qtext).segments[:k]
    retrieved = []
    for s in top_segments:
        # Recover the doc_id of each returned segment through the identity map.
        retrieved.append(segment_doc_ids[seg_id_to_idx[id(s)]])

    print(banner)
    print("QID:", qid)
    print("QUERY:", qtext)
    print("GOLD DOC IDS:", gold_ids[:5])  # show at most five gold ids

    print("\nTOP DENSE RESULTS:")
    for rank, doc_id in enumerate(retrieved, start=1):
        doc = corpus[doc_id]
        # Title and body joined by a space, cut to 250 characters for display.
        preview = (doc.get("title", "") + " " + doc.get("text", ""))[:250]
        print(f"{rank}. {doc_id} | {preview}...")

    print(banner)

In [18]:
# Print the first 30 failures (default k=5 results each) to look for
# recurring failure patterns by eye.
for qid in dense_fail_qids[:30]:
    inspect_dense_fail(qid)

QID: 13
QUERY: 5% of perinatal mortality is due to low birth weight.
GOLD DOC IDS: ['1606628']

TOP DENSE RESULTS:
1. 7662395 | Perinatal mortality in rural China: retrospective cohort study. OBJECTIVES To explore the use of local civil registration data to assess the perinatal mortality in a typical rural county in a less developed province in China, 1999-2000. DESIGN Retros...
2. 4791384 | Neonatal Mortality Levels for 193 Countries in 2009 with Trends since 1990: A Systematic Analysis of Progress, Projections, and Priorities BACKGROUND Historically, the main focus of studies of childhood mortality has been the infant and under-five mo...
3. 1263446 | Determinants of neonatal mortality in Indonesia BACKGROUND Neonatal mortality accounts for almost 40 per cent of under-five child mortality, globally. An understanding of the factors related to neonatal mortality is important to guide the development...
4. 11748341 | Evidence-based interventions for improvement of maternal and child nut

In [32]:
# Hand-assigned failure categories: category name -> set of query ids.
# The ids are strings, like the qids in `queries`; an int here would never
# match a `qid in qids` membership test.
label_map = {
    # Query and gold document presumably say the same thing in different words.
    "Paraphrase": {
        "1", "13", "132", "238", "239", "384", "623", "800", "870", "913", "914",
        "975", "1049", "1088", "1110", "1191", "1199", "1213", "1226", "1241",
        "1278", "1279", "1316", "1332", "1368"
    },

    # Presumably failures that hinge on specific named entities — TODO confirm definition.
    "Entities": {
        "431", "437", "517", "535", "544", "577", "690", "715", "716", "768",
        "775", "830", "831", "1200", "1280", "1281", "1382"
    },

    # Presumably queries packed with many competing keywords — TODO confirm definition.
    "Keyword overload": {
        "502", "820", "821", "887", "1196", "1197", "783", "1395"
    },

    # Presumably queries that depend on rare vocabulary — TODO confirm definition.
    "Rare terms": {
        "1140", "1221"
    }
}


In [33]:
# Check the qid type (and a sample value) so label_map entries can be written
# in the same type.
type(dense_fail_qids[0]), dense_fail_qids[0]


(str, '13')

In [34]:
# Rank cut-off for hit@k. The evaluation loop slices the retrieved segments
# to the first k before scoring them.
k = 10

def is_hit_at_k(retrieved_doc_ids, qrel_dict):
    """Return True if at least one retrieved document is relevant to the query.

    Args:
        retrieved_doc_ids: Iterable of document ids, already cut to the rank
            the caller cares about (this function does no slicing itself).
        qrel_dict: Mapping of doc_id -> numeric relevance judgement for a
            single query.

    Returns:
        bool: True when any retrieved id has a positive relevance judgement.
        An empty retrieval list or an empty qrel mapping yields False.
    """
    # Only positively judged documents count as gold: a document listed with
    # relevance 0 was judged and found NOT relevant, so it must not score a hit.
    gold = {doc_id for doc_id, relevance in qrel_dict.items() if relevance > 0}
    return any(doc_id in gold for doc_id in retrieved_doc_ids)

# Re-run the dense evaluation, this time keeping a per-query record in
# qid_categories so that failures can be labelled afterwards.

# hit@k is only what it claims to be if the retriever returns at least k
# results: slicing past dense.top_k would silently report hit@top_k under the
# wrong name.
assert 0 < k <= dense.top_k, f"k={k} must be between 1 and dense.top_k={dense.top_k}"

hits = []
dense_fail_qids = []
qid_categories = {}

for qid, qtext in queries.items():
    result = dense.retrieve(qtext)  # RetrievalResult

    # Convert the first k retrieved segments to doc_ids via the identity map.
    retrieved_doc_ids = [
        segment_doc_ids[seg_id_to_idx[id(seg)]] for seg in result.segments[:k]
    ]

    hit = is_hit_at_k(retrieved_doc_ids, qrels.get(qid, {}))
    hits.append(hit)
    if not hit:
        dense_fail_qids.append(qid)

    qid_categories[qid] = {
        "dense": hit,
        "query": qtext,
        # Failures start unlabelled (None) and are filled in manually via label_map.
        "Failure Category": "not a failure" if hit else None,
    }

# Guard the ratio so an empty query set reports 0.000 instead of dividing by zero.
hit_rate = sum(hits) / len(hits) if hits else 0.0
print(f"Dense hit@{k}: {sum(hits)}/{len(hits)} = {hit_rate:.3f}")
print("Dense failures:", len(dense_fail_qids))
dense_fail_qids[:20]


Dense hit@10: 238/300 = 0.793
Dense failures: 62


['13',
 '48',
 '70',
 '128',
 '132',
 '198',
 '239',
 '294',
 '303',
 '312',
 '314',
 '384',
 '415',
 '421',
 '431',
 '437',
 '475',
 '502',
 '508',
 '517']

In [35]:
# Attach the manual labels from label_map to every dense failure.

# Invert label_map once (qid -> label) instead of scanning every label set for
# every query. setdefault keeps the FIRST label in label_map order if a qid is
# listed under two categories, i.e. first match wins.
qid_to_label = {}
for label, qids in label_map.items():
    for labelled_qid in qids:
        qid_to_label.setdefault(labelled_qid, label)

for qid, record in qid_categories.items():
    if record["dense"]:
        continue  # only label failures
    # Any failure not in label_map becomes "Unlabeled".
    record["Failure Category"] = qid_to_label.get(qid, "Unlabeled")

# Labels that point at a query which is not a current dense failure (or is not
# in this query set at all) are never applied. Report them, because a large
# number means label_map was written against a different evaluation run.
stale_label_qids = sorted(
    labelled_qid
    for labelled_qid in qid_to_label
    if labelled_qid not in qid_categories or qid_categories[labelled_qid]["dense"]
)
if stale_label_qids:
    print(
        f"label_map lists {len(stale_label_qids)} qid(s) that are not current "
        f"dense failures: {stale_label_qids}"
    )


In [36]:
import pandas as pd

# One row per query: the qid plus its record (dense, query, Failure Category).
records = [{"qid": qid, **record} for qid, record in qid_categories.items()]
df = pd.DataFrame(records)

# only failures ("dense" holds plain booleans, so negating the mask selects the misses)
fail_df = df.loc[~df["dense"]].copy()

# How many failures fall into each category.
fail_df["Failure Category"].value_counts()


Failure Category
Unlabeled           30
Paraphrase          15
Entities            13
Keyword overload     4
Name: count, dtype: int64

In [37]:
# Sanity check: qids must have the same type on both sides, otherwise the
# `qid in qids` membership test used for labelling can never match.
print("qid type:", type(next(iter(qid_categories))))

# Peek at one element of the first label set. The loop variable is
# deliberately NOT called `k`: at notebook level that would overwrite the
# hit@k cut-off `k` with a label string.
for label_qids in label_map.values():
    print("label_map type:", type(next(iter(label_qids))))
    break


qid type: <class 'str'>
label_map type: <class 'str'>
