# Building contrastive examples

In [1]:
import pandas as pd
import spacy

df = pd.read_csv('../../data/raw/hyperbook.csv', index_col='url')
nlp = spacy.load('en_core_sci_md')  # This can take a while

  deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(  # type: ignore[union-attr]


In [165]:
from collections import deque

def group_intersecting_sets(sets_list):
    # Keep track of which sets have been visited
    visited = [False] * len(sets_list)
    result = []

    for i, s in enumerate(sets_list):
        if visited[i]:
            continue

        # Start a new component
        group = []
        queue = deque([i])
        visited[i] = True

        while queue:
            idx = queue.popleft()
            group.append(sets_list[idx])

            for j, other_set in enumerate(sets_list):
                if not visited[j] and not sets_list[idx].isdisjoint(other_set):
                    visited[j] = True
                    queue.append(j)

        result.append(group)

    return result

authors = [set(authors.split(',')) for authors in df['authors']]
triplets = group_intersecting_sets(authors)

In [145]:
import re

FIG_RE = re.compile(r"^(fig(ure)?|table|eq(uation)?)\s*\d+", re.I)
CIT_RE = re.compile(r"^\[?\d{1,3}\]?$")

def is_good(sent: str) -> bool:
    """
    Validates a sentence and returns a boolean.
    Excludes figures, citations, section heads, and
    sentences with alphabetic ratios less than 40%,
    less than 4 tokens or more than 128 tokens.
    """
    text = sent.strip()
    # Exclude figures and citations
    if FIG_RE.match(text) or CIT_RE.match(text): return False
    # Exclude section heads
    if text.isupper() and len(text.split()) <= 6: return False
    doc = nlp(text)
    if len(doc) < 4 or len(doc) > 128: return False
    # Ensure alphabetic ratio greater than 40%
    alpha_ratio = sum(t.is_alpha for t in doc)/len(doc) 
    return alpha_ratio >= 0.4

def clean_paper(content):
    doc = nlp(content)
    for sent in doc.sents:
        if is_good(sent.text):
            yield sent.text.strip()

In [185]:
organism_maps = {}

papers = list(df['content'])

for pid in range(len(papers)):
    for organism_id, triplet in enumerate(triplets):
        if authors[pid] in triplet:
            organism_maps[pid] = organism_id

In [None]:
from tqdm import tqdm

sentences, positives, paper_ids = [], [], []

for pid, content in tqdm(list(enumerate(papers))):
    sents = list(clean_paper(content))
    # Last sentence has no positive
    for i in range(len(sents) - 1):
        sentences.append(sents[i])
        positives.append(sents[i + 1])
        paper_ids.append(pid)


100%|██████████| 207/207 [10:55<00:00,  3.17s/it]


In [190]:
organism_ids = [organism_maps[pid] for pid in paper_ids]

In [191]:
import json

data = {
    'sentences': sentences, 
    'paper_ids': paper_ids, 
    'organism_ids': organism_ids, 
    'positives': positives
}

output_path = '../../data/processed/indexed_sentences.json'

with open(output_path, 'w') as file:
    json.dump(data, file)

In [None]:
import json 

output_path = '../../data/processed/indexed_sentences.json'

with open(output_path) as file:
    data = json.load(file)

sentences, positives, paper_ids, organism_ids = data['sentences'], data['positives'], data['paper_ids'], data['organism_ids']

## Hard-negative mining for InfoNCE

In [152]:
from sentence_transformers import SentenceTransformer

model_name = 'Qwen/Qwen3-Embedding-0.6B'
model = SentenceTransformer(model_name)

In [132]:
embeddings = model.encode(
    sentences, 
    batch_size=64, 
    normalize_embeddings=True,  # Normalizes into unit vectors
    show_progress_bar=True
)

Batches:   0%|          | 0/800 [00:00<?, ?it/s]

In [133]:
import faiss

d = embeddings.shape[1]
index = faiss.IndexFlatIP(d)    # Cosine because vectors are unit-norm
index.add(embeddings)

In [157]:
embeddings.shape

(51146, 1024)

In [195]:
def mine_negatives(i, k=3, lower=0.6, upper=0.9):
    # Get the top 50 candidates
    D, I = index.search(embeddings[i:i+1], 50)
    hard = []

    for j, score in zip(I[0], D[0]):
        if j == i: continue
        if paper_ids[j] == paper_ids[i]: continue
        if organism_ids[j] == organism_ids[i]: continue
        if lower <= score <= upper: hard.append(sentences[j])
        if len(hard) == k: break

    return hard

In [196]:
import json

output_path = '../../data/processed/hyperbook_infonce.jsonl'

with open(file=output_path, mode='w', encoding='utf-8') as file:
    for i, (anchor, positive) in tqdm(list(enumerate(zip(sentences, positives)))):
        negatives = mine_negatives(i)
        file.write(json.dumps({
            'query': anchor,
            'response': positive,
            'rejected_response': negatives
        }, ensure_ascii=False) + '\n')

100%|██████████| 51146/51146 [12:18<00:00, 69.29it/s]


In [197]:
import pandas as pd
import json

data = []

with open('../../data/processed/hyperbook_infonce.jsonl', 'r', encoding='utf-8') as json_file:
    json_list = list(json_file)

for json_str in json_list:
    data.append(json.loads(json_str))