In [None]:

import json
import os
import random
import shutil

import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, InputExample
from sentence_transformers.evaluation import InformationRetrievalEvaluator

# === Config ===
MODEL_NAME = "BAAI/bge-large-en-v1.5"
# JSON Lines file: one JSON object per line with "anchor", "positive" and
# "negative" text fields (the keys read when building examples below).
TRIPLET_FILE = "data/us_gaap_triplet_training_data.jsonl"
OUTPUT_PATH = "fine_tuned_gaap_model"
EVAL_METRICS_PATH = "eval_metrics.csv"
BATCH_SIZE = 16
EPOCHS = 10
# NOTE(review): with at most MAX_TRIPLETS - EVAL_SAMPLES = 4000 training
# triplets and BATCH_SIZE 16, an epoch is at most 250 steps, so a 500-step
# interval may never fire mid-epoch depending on how the installed
# sentence-transformers version counts steps — confirm.
EVAL_STEPS = 500
# NOTE(review): not passed to model.fit() in this notebook, so no early
# stopping actually happens; kept only in case other cells reference it.
EARLY_STOPPING_PATIENCE = 2
# Upper slice index into the shuffled data, so this caps eval + train
# combined (train gets at most MAX_TRIPLETS - EVAL_SAMPLES triplets).
MAX_TRIPLETS = 5000
# Number of leading shuffled triplets held out for the IR evaluator.
EVAL_SAMPLES = 1000

# Pick the fastest available backend: NVIDIA CUDA, then Apple-silicon MPS,
# then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

# === Load Triplet Data ===
# Fixed seed for the eval/train split. The held-out set then stays identical
# across reruns, so metrics from different training runs remain comparable.
SPLIT_SEED = 42

with open(TRIPLET_FILE, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
    all_triplets = [json.loads(line) for line in f if line.strip()]

# A dedicated Random instance leaves the global `random` state untouched.
random.Random(SPLIT_SEED).shuffle(all_triplets)

# The first EVAL_SAMPLES shuffled triplets are held out for evaluation.
# Training uses the rest up to index MAX_TRIPLETS.
eval_triplets = all_triplets[:EVAL_SAMPLES]
train_triplets = all_triplets[EVAL_SAMPLES:MAX_TRIPLETS]

if not train_triplets:
    # Fail fast: an empty training set otherwise surfaces later as an
    # obscure error inside the DataLoader / fit loop.
    raise ValueError(
        f"No training triplets left after the split: loaded {len(all_triplets)}, "
        f"EVAL_SAMPLES={EVAL_SAMPLES}, MAX_TRIPLETS={MAX_TRIPLETS}."
    )

# === Build IR evaluator over the held-out triplets ===
# Each eval triplet contributes one query (its anchor) and two corpus
# documents (its positive and negative). Every query is searched against the
# pooled corpus of all eval documents. Only the query's own positive counts
# as relevant.
queries = {}
relevant_docs = {}
# Map document text -> corpus id so identical texts share one id. Otherwise a
# text stored under two ids would produce (near-)identical scores, and the
# copy not marked relevant could be ranked first and counted as a miss.
# Examples: the same positive for two anchors, or one anchor's negative equal
# to another's positive.
doc_id_by_text = {}


def _corpus_id(text):
    """Return the corpus id for *text*, assigning a new one on first sight."""
    if text not in doc_id_by_text:
        doc_id_by_text[text] = f"d{len(doc_id_by_text)}"
    return doc_id_by_text[text]


for i, triplet in enumerate(eval_triplets):
    qid = f"q{i}"
    queries[qid] = triplet["anchor"]
    pos_id = _corpus_id(triplet["positive"])
    _corpus_id(triplet["negative"])  # registered as a distractor only
    relevant_docs[qid] = {pos_id}

corpus = {doc_id: text for text, doc_id in doc_id_by_text.items()}

# `write_csv` is a boolean flag. The evaluator names the CSV itself (its
# `csv_file` attribute) and writes it into the `output_path` it is called
# with. It cannot be pointed at an arbitrary file such as EVAL_METRICS_PATH.
evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    show_progress_bar=True,
    name="gaap-ir-eval",
    write_csv=True,
)

# === Build training data ===
# Text order inside each example is (anchor, positive, negative), the order
# the triplet loss consumes them in.
_TRIPLET_FIELDS = ("anchor", "positive", "negative")
train_examples = [
    InputExample(texts=[record[field] for field in _TRIPLET_FIELDS])
    for record in train_triplets
]

train_dataloader = DataLoader(
    train_examples,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
model = SentenceTransformer(MODEL_NAME, device=device)
# NOTE(review): TripletLoss is used with its library defaults. Per the
# sentence-transformers docs these are Euclidean distance with margin 5;
# verify against the installed version. The docs generally recommend
# MultipleNegativesRankingLoss for (anchor, positive, negative) data.
train_loss = losses.TripletLoss(model=model)

# === Train ===
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=evaluator,
    epochs=EPOCHS,
    evaluation_steps=EVAL_STEPS,
    output_path=OUTPUT_PATH,
    show_progress_bar=True,
)

# === Export evaluation metrics ===
# The IR evaluator writes its CSV (named `evaluator.csv_file`) into whatever
# directory it is called with. Where fit() sends its in-training evaluations
# differs between sentence-transformers versions. So run one final
# evaluation into a known directory, then copy that CSV to
# EVAL_METRICS_PATH. The message printed below is then actually true.
# NOTE(review): this scores the in-memory weights after the last step, which
# may differ from the best-scoring model fit() saved to OUTPUT_PATH — confirm.
eval_dir = os.path.join(OUTPUT_PATH, "eval")
os.makedirs(eval_dir, exist_ok=True)
evaluator(model, output_path=eval_dir)

metrics_dir = os.path.dirname(EVAL_METRICS_PATH)
if metrics_dir:
    os.makedirs(metrics_dir, exist_ok=True)
shutil.copyfile(os.path.join(eval_dir, evaluator.csv_file), EVAL_METRICS_PATH)

print(f"✅ Fine-tuned model saved to: {OUTPUT_PATH}")
print(f"📊 Evaluation metrics saved to: {EVAL_METRICS_PATH}")
