In [1]:
!pip install sentence-transformers datasets rank-bm25 tqdm




In [2]:
from datasets import load_dataset

# Hugging Face dataset repository and the configuration to pull from it.
dataset_repo = "mattmorgis/bioasq-12b-rag"
dataset_config = "question-answer-passages"

ds = load_dataset(dataset_repo, dataset_config)

# The "dev" split is the one this notebook trains on.
train_ds = ds["dev"]

In [None]:
from tqdm import tqdm

# One {"query", "positive"} record per non-blank snippet of every question.
positive_pairs = []

for example in tqdm(train_ds):
    question = example["question"]

    # Missing "snippets" / "text" keys fall back to empty values; snippets whose
    # text is empty or whitespace-only are dropped.
    snippet_texts = (snippet.get("text", "") for snippet in example.get("snippets", []))
    positive_pairs.extend(
        {"query": question, "positive": snippet_text}
        for snippet_text in snippet_texts
        if snippet_text.strip()
    )

100%|██████████| 5049/5049 [00:01<00:00, 4110.37it/s]


In [4]:
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np

# The retrieval corpus is the list of positive passages, index-aligned with
# positive_pairs (corpus[i] is the passage of positive_pairs[i]).
corpus = [pair["positive"] for pair in positive_pairs]

# Sparse TF-IDF representation: English stop words removed, vocabulary capped
# at 50k terms.
tfidf_options = {"stop_words": "english", "max_features": 50000}
vectorizer = TfidfVectorizer(**tfidf_options)
corpus_embeddings = vectorizer.fit_transform(corpus)

In [5]:
import random
from sklearn.metrics.pairwise import linear_kernel

# Mine one lexical "hard negative" per (query, positive) pair: a passage that
# scores among the top TF-IDF matches for the query but is NOT one of the
# passages annotated as relevant to that query.

NUM_NEGATIVE_CANDIDATES = 50  # size of the top-scoring pool negatives are drawn from
NEGATIVE_SAMPLING_SEED = 42   # fixed seed so the mined triples are reproducible

triples = []
queries = [p["query"] for p in positive_pairs]
query_embeddings = vectorizer.transform(queries)

# positive_pairs holds one entry per snippet, so a question with several
# snippets appears several times. Collect ALL positives of each question so
# that none of them can be sampled as a "negative" for that same question
# (filtering only the current pair's positive would produce false negatives).
positives_by_query = {}
for p in positive_pairs:
    positives_by_query.setdefault(p["query"], set()).add(p["positive"])

# Dedicated RNG instance: reproducible without touching the global random state.
negative_rng = random.Random(NEGATIVE_SAMPLING_SEED)

# argpartition raises if asked for more elements than exist, so clamp the pool
# size to the corpus size.
top_k = min(NUM_NEGATIVE_CANDIDATES, len(corpus))

# Score queries against the corpus in batches to bound the size of the dense
# score matrix held in memory at once.
batch_size = 1024
for i in tqdm(range(0, len(queries), batch_size)):
    end = min(i + batch_size, len(queries))
    batch_scores = linear_kernel(query_embeddings[i:end], corpus_embeddings)

    for j, idx in enumerate(range(i, end)):
        scores = batch_scores[j]
        q = queries[idx]
        pos = positive_pairs[idx]["positive"]
        known_positives = positives_by_query[q]

        # Indices of the top_k highest-scoring passages (unordered within the pool).
        top_k_indices = np.argpartition(scores, -top_k)[-top_k:]

        # `pos` is itself in known_positives, so this also excludes the pair's
        # own positive passage.
        cand = [
            corpus[cand_idx]
            for cand_idx in top_k_indices
            if corpus[cand_idx] not in known_positives
        ]

        # Pairs whose whole candidate pool consists of positives are skipped.
        if cand:
            neg = negative_rng.choice(cand)
            triples.append({
                "query": q,
                "positive": pos,
                "negative": neg
            })


100%|██████████| 59/59 [03:06<00:00,  3.16s/it]


In [6]:
from sentence_transformers import InputExample

# Expand every triple into two labelled (query, passage) examples, in this
# order: the positive passage with label 1.0, then the negative with label 0.0.
labelled_passage_keys = (("positive", 1.0), ("negative", 0.0))

train_samples = [
    InputExample(texts=[triple["query"], triple[passage_key]], label=relevance)
    for triple in triples
    for passage_key, relevance in labelled_passage_keys
]


In [7]:
from sentence_transformers import CrossEncoder
from torch.utils.data import DataLoader

# Base checkpoint for the cross-encoder re-ranker.
model_name = "dmis-lab/biobert-base-cased-v1.1"
# num_labels=1 requests a single output per (query, passage) pair. As the load
# warning shows, the classifier weights are not in the checkpoint and are newly
# initialized, so the model must be trained before its scores are usable.
model = CrossEncoder(model_name, num_labels=1)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
import os

# Keep Weights & Biases logging switched off for this run.
os.environ["WANDB_DISABLED"] = "true"

# Directory that receives the trained model (used by both fit() and save()).
output_dir = "/content/model"

# Shuffled mini-batches of 16 labelled (query, passage) examples.
train_dataloader = DataLoader(train_samples, batch_size=16, shuffle=True)

# Fine-tune for 2 epochs with 100 warmup steps.
model.fit(
    train_dataloader=train_dataloader,
    epochs=2,
    warmup_steps=100,
    output_path=output_dir,
)

# Explicitly persist the final weights to the same location.
model.save(output_dir)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss
500,0.5724
1000,0.5157
1500,0.4983
2000,0.4888
2500,0.4844
3000,0.4777
3500,0.4792
4000,0.4712
4500,0.4573
5000,0.4555
