In [3]:
import gzip, csv, json, re, itertools, random, time
from pathlib import Path
import torch
from tqdm.auto import tqdm
from sentence_transformers import models, SentenceTransformer, InputExample, losses, util
from torch.utils.data import DataLoader
from datasets import Dataset, DatasetDict

# Inject `DatasetDict` into sentence_transformers.fit_mixin's module namespace.
# NOTE(review): presumably a workaround for a sentence-transformers build whose
# fit_mixin references `DatasetDict` without having it bound at runtime — confirm
# against the installed version and drop the patch once it is no longer needed.
import sentence_transformers.fit_mixin as fit_mixin
fit_mixin.DatasetDict = DatasetDict

# ── Paths and model identifiers ─────────────────────────────────────────────
# Directory that holds the two gzipped input files below.
DATA_DIR      = Path("data/GSEA/external_gene_data/store!")
# Gene table read with csv.DictReader (columns used: "Gene stable ID", "Gene name", "Gene description").
GENE_FILE     = DATA_DIR / "rat_genes_consolidated.txt.gz"
# Pathway file: first tab-separated field is the pathway name, synonyms are read from "[...]" groups.
PATHWAY_FILE  = DATA_DIR / "wikipathways_synonyms_Rattus_norvegicus.gmt.gz"
# Generated training pairs, one JSON object ({"text1", "text2"}) per line.
OUT_JSONL     = Path("train_pairs.jsonl")
# Model id handed to models.Transformer as the encoder to fine-tune.
BASE_MODEL    = "michiyasunaga/BioLinkBERT-large"
# Passed to model.fit(output_path=...); the later cells load/upload from this same folder.
OUTPUT_FOLDER = "./output/model/biolinkbert-large-simcse-rat"
# Use the GPU when torch can see one, otherwise fall back to CPU.
DEVICE        = "cuda" if torch.cuda.is_available() else "cpu"

# ── Training hyper-parameters ───────────────────────────────────────────────
BATCH_SIZE     = 64            # DataLoader batch size; larger batches give the contrastive loss more in-batch negatives
EPOCHS         = 25             # epochs passed to model.fit(). NOTE(review): no early stopping is wired into train() in this cell
LEARNING_RATE  = 3e-5          # optimizer learning rate (passed via optimizer_params)
WEIGHT_DECAY   = 0.01          # L2 regularization for AdamW
WARMUP_RATIO   = 0.1           # fraction of total steps (len(loader) * EPOCHS) used for LR warmup
MAX_GRAD_NORM  = 1.0           # gradient clipping
EVAL_RATIO     = 0.1           # intended validation fraction. NOTE(review): no train/validation split is performed in this cell
PATIENCE       = 3             # intended early-stopping patience. NOTE(review): not referenced by train() in this cell
SEED           = 42            # fixed seed for reproducibility


def add_pairs(pairs, texts):
    """Append every unique 2-combination of ``texts`` to ``pairs`` as a positive pair.

    Args:
        pairs: List that is extended in place with ``{"text1": a, "text2": b}`` dicts.
        texts: Iterable of candidate texts. Falsy entries (``None``, ``""``) and
            entries that are blank after stripping are ignored; the remaining
            entries are converted with ``str()`` and stripped.

    Duplicates (after stripping) are collapsed while keeping first-seen order,
    so the emitted pairs and their text1/text2 orientation are deterministic.
    Iterating a ``set`` of strings here would make both depend on Python's
    per-process hash randomization.
    """
    cleaned = [str(t).strip() for t in texts if t and str(t).strip()]
    # dict.fromkeys de-duplicates like a set but preserves insertion order.
    unique = list(dict.fromkeys(cleaned))
    for a, b in itertools.combinations(unique, 2):
        pairs.append({"text1": a, "text2": b})

def build_and_write_pairs():
    """Build positive text pairs from the gene and pathway files and dump them to JSONL.

    Sources:
        * ``GENE_FILE``: for each row, every 2-combination of the stable ID,
          gene name and gene description becomes a pair (via ``add_pairs``).
        * ``PATHWAY_FILE``: for each non-blank line, every ``[...]`` group is
          split on commas into synonyms; synonyms are paired with each other
          and each synonym is additionally paired with the pathway name (the
          first tab-separated field, whitespace-normalized).

    The collected pairs are shuffled with ``SEED`` and written to ``OUT_JSONL``
    as one JSON object per line; a one-line summary is printed at the end.
    """
    pairs = []

    # Gene IDs ↔ name ↔ description
    with gzip.open(GENE_FILE, "rt") as fh:
        for row in csv.DictReader(fh):
            add_pairs(pairs, [row["Gene stable ID"],
                              row["Gene name"],
                              row["Gene description"]])

    # Pathway synonyms inside [ … ]
    bracket_re = re.compile(r"\[([^\]]+)\]")
    with gzip.open(PATHWAY_FILE, "rt") as fh:
        for line in fh:
            if not line.strip():
                continue
            pathway = re.sub(r"\s+", " ", line.split("\t")[0]).strip()
            for grp in bracket_re.findall(line):
                syns = [g.strip() for g in grp.split(",") if g.strip()]
                add_pairs(pairs, syns)
                pairs.extend({"text1": pathway, "text2": s} for s in syns)

    # Shuffle with a private RNG seeded from SEED: the permutation is repeatable
    # and the global `random` state used elsewhere is left untouched.
    random.Random(SEED).shuffle(pairs)

    with OUT_JSONL.open("w") as out:
        for ex in pairs:
            out.write(json.dumps(ex, ensure_ascii=False) + "\n")
    print(f"✅ Wrote {len(pairs):,} pairs → {OUT_JSONL} "
          f"({OUT_JSONL.stat().st_size/1e6:.2f} MB)")

def prepare_training():
    """Load the JSONL pairs and build the model, data loader and loss for training.

    Returns:
        tuple: ``(model, loader, loss)`` where ``model`` is a SentenceTransformer
        made of the ``BASE_MODEL`` transformer plus mean pooling, ``loader`` is a
        shuffling DataLoader over the ``InputExample`` pairs, and ``loss`` is a
        ``MultipleNegativesRankingLoss`` bound to ``model``.
    """
    # Read inside a context manager so the file handle is closed deterministically;
    # blank lines are skipped so a stray trailing newline cannot break json.loads.
    with OUT_JSONL.open() as fh:
        examples = [
            InputExample(texts=[d["text1"], d["text2"]])
            for d in (json.loads(line) for line in fh if line.strip())
        ]

    # drop_last=True discards the final partial batch so every step sees BATCH_SIZE pairs.
    loader = DataLoader(examples,
                        batch_size=BATCH_SIZE,
                        shuffle=True,
                        drop_last=True)

    word_model = models.Transformer(BASE_MODEL, max_seq_length=128)
    # Enable gradient checkpointing on the underlying HF model before training.
    word_model.auto_model.gradient_checkpointing_enable()
    pool_model = models.Pooling(
        word_model.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True
    )
    print(DEVICE)
    model = SentenceTransformer(modules=[word_model, pool_model],
                                device=DEVICE)

    loss = losses.MultipleNegativesRankingLoss(model)

    return model, loader, loss

def train(model, loader, loss):
    """Fine-tune ``model`` on ``loader`` with ``loss`` and save it to ``OUTPUT_FOLDER``.

    Args:
        model: SentenceTransformer to train (as returned by ``prepare_training``).
        loader: DataLoader yielding the training pairs.
        loss: Loss module paired with ``loader`` in ``train_objectives``.

    Warmup length is ``WARMUP_RATIO`` of the total step count
    (``len(loader) * EPOCHS``). ``WEIGHT_DECAY`` and ``MAX_GRAD_NORM`` are passed
    through explicitly so the values in the configuration section actually
    govern the run instead of the library defaults. The wall-clock duration is
    printed when training finishes, matching this cell's saved output.
    """
    total_steps = len(loader) * EPOCHS
    warmup_steps = int(WARMUP_RATIO * total_steps)

    started = time.time()
    model.fit(
        train_objectives=[(loader, loss)],
        epochs=EPOCHS,
        optimizer_params={"lr": LEARNING_RATE},
        weight_decay=WEIGHT_DECAY,
        max_grad_norm=MAX_GRAD_NORM,
        warmup_steps=warmup_steps,
        use_amp=True,
        output_path=OUTPUT_FOLDER,
        show_progress_bar=True
    )
    print(f"✅ Training done in {(time.time() - started) / 60:.1f} min")



if __name__ == "__main__":
    # Stage 1: regenerate the JSONL file of positive pairs.
    build_and_write_pairs()
    # Stage 2 + 3: build (model, loader, loss) and feed them straight into training.
    train(*prepare_training())


✅ Wrote 121,533 pairs → train_pairs.jsonl (7.15 MB)
cuda


                                                                     

Step,Training Loss
500,4.3779
1000,3.7658
1500,2.9926
2000,2.6905
2500,2.5404
3000,2.4603
3500,2.3984
4000,2.3122
4500,2.2417
5000,2.1948


✅ Training done in 141.1 min


In [1]:
from sentence_transformers import SentenceTransformer, models, util
import torch


# Run on the GPU when torch can see one, otherwise on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Identifiers of the two models under comparison.
BASE_MODEL = "michiyasunaga/BioLinkBERT-large"
FT_FOLDER = "./output/model/biolinkbert-large-simcse-rat"

# Baseline: the unmodified encoder wrapped with mean pooling over token
# embeddings, assembled by hand into a SentenceTransformer.
word_model = models.Transformer(BASE_MODEL, max_seq_length=128)
pool_model = models.Pooling(
    word_model.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
)
base = SentenceTransformer(modules=[word_model, pool_model], device=device)

# Fine-tuned model: loaded directly from the folder written by training.
ft = SentenceTransformer(FT_FOLDER, device=device)

# Hand-written evaluation set. Each entry carries a target cosine similarity;
# the loop below marks a pair as "WORKED" when the fine-tuned model's cosine is
# closer to that target than the base model's.
# NOTE(review): the targets look hand-assigned rather than measured — confirm
# their provenance before reading the success rate as a quality metric.
test_pairs = [
    # (text1, text2, expected_cosine)
    # High targets (0.85–0.90): presumably a gene symbol / stable ID with its descriptive name.
    ("Alx4",       "ALX homeobox 4",              0.90),
    ("Dgkh",       "diacylglycerol kinase, eta",  0.85),
    ("Prdm11",     "PR/SET domain 11",            0.87),
    ("Spata3",     "spermatogenesis associated 3",0.89),
    ("ENSRNOG00000015159", "solute carrier family 9 member A3", 0.88),
    # Mid targets (0.48–0.55): presumably a gene/protein name with a pathway it belongs to.
    ("Abcc1",     "Irinotecan pathway",           0.55),
    ("Arrb2",     "Wnt signaling pathway",        0.50),
    ("Ppara",     "Nuclear receptors",            0.52),
    ("gamma-PAK", "Regulation of actin cytoskeleton", 0.48),
    # Low targets (0.05–0.10): presumably unrelated pairs used as negatives.
    ("Alx4",   "glutamate decarboxylase 1", 0.10),
    ("Dgkh",   "Arrb2",                     0.05),
    ("Prdm11", "Irinotecan pathway",        0.08),
]

# Running tally of pairs for which the fine-tuned model lands closer to the target.
worked, total = 0, len(test_pairs)

# Table header.
print(f"{'Text 1':25s} ↔ {'Text 2':30s}  {'Exp':>5s}  {'Base':>6s}  {'FT':>6s}   Result")
print("-" * 85)

for text1, text2, exp in test_pairs:
    # Embed both texts with the base model first, then with the fine-tuned one.
    emb_b1, emb_b2 = (base.encode(t, convert_to_tensor=True) for t in (text1, text2))
    emb_f1, emb_f2 = (ft.encode(t, convert_to_tensor=True) for t in (text1, text2))

    # Cosine similarity of the pair under each model.
    sim_base = util.cos_sim(emb_b1, emb_b2).item()
    sim_ft = util.cos_sim(emb_f1, emb_f2).item()

    # Absolute distance of each similarity from the hand-set target.
    dist_base, dist_ft = abs(sim_base - exp), abs(sim_ft - exp)

    # The pair counts as a success only when fine-tuning moved strictly closer.
    if dist_ft < dist_base:
        result = "WORKED ✅"
        worked += 1
    else:
        result = "DIDN’T WORK ❌"

    print(f"{text1:25s} ↔ {text2:30s}  {exp:5.2f}  {sim_base:6.3f}  {sim_ft:6.3f}   {result}")

# Overall success rate
pct = worked / total * 100
print("\nSummary:")
print(f"  {worked}/{total} pairs closer to fine-tuned → {pct:.1f}% worked")


  from .autonotebook import tqdm as notebook_tqdm


Text 1                    ↔ Text 2                            Exp    Base      FT   Result
-------------------------------------------------------------------------------------
Alx4                      ↔ ALX homeobox 4                   0.90   0.655   0.949   WORKED ✅
Dgkh                      ↔ diacylglycerol kinase, eta       0.85   0.804   0.751   DIDN’T WORK ❌
Prdm11                    ↔ PR/SET domain 11                 0.87   0.891   0.916   DIDN’T WORK ❌
Spata3                    ↔ spermatogenesis associated 3     0.89   0.707   0.892   WORKED ✅
ENSRNOG00000015159        ↔ solute carrier family 9 member A3   0.88   0.737   0.580   DIDN’T WORK ❌
Abcc1                     ↔ Irinotecan pathway               0.55   0.551   0.591   DIDN’T WORK ❌
Arrb2                     ↔ Wnt signaling pathway            0.50   0.849   0.425   WORKED ✅
Ppara                     ↔ Nuclear receptors                0.52   0.810   0.569   WORKED ✅
gamma-PAK                 ↔ Regulation of actin cytoskel

In [None]:
# `os` is imported here because this cell reads the token from the environment
# and no other cell in the notebook imports it; without it the cell raises
# NameError on a fresh kernel.
import os

from huggingface_hub import HfApi

# The access token is taken from the HF_TOKEN environment variable so that no
# credential is stored in the notebook; os.getenv yields None when it is unset.
api = HfApi(token=os.getenv("HF_TOKEN"))

# Upload the fine-tuned model folder to the Hub model repository.
api.upload_folder(
    folder_path="./output/model/biolinkbert-large-simcse-rat",
    repo_id="mghuibregtse/biolinkbert-large-simcse-rat",
    repo_type="model",
)
