# In [None]:
# Updated version with progress logging during triplet generation
from sentence_transformers import SentenceTransformer
import csv
import json
import random
from collections import defaultdict
from numpy import dot
from numpy.linalg import norm

# Input table read with csv.DictReader further down; the script accesses the
# "statement_type", "balance", "period_type", "description" and "tag" columns.
# (Presumably one row per US-GAAP taxonomy tag, judging by the file name.)
INPUT_CSV = "data/us_gaap_2025_verified_subcategory_path.csv"
# Destination for the accepted triplets, written as one JSON object per line.
OUTPUT_JSONL = "data/us_gaap_triplet_training_data_filtered.jsonl"
# Upper bound on accepted triplets; generation stops once this many are kept.
MAX_TRIPLETS = 5000
# A triplet is kept only when sim(anchor, positive) exceeds
# sim(anchor, negative) by more than this cosine-similarity margin.
MIN_MARGIN = 0.05

# Embedding model used to score candidate triplets. It is constructed at module
# level, so it is created as soon as this script/cell runs.
# NOTE(review): presumably this loads (and may download) model weights — confirm.
model = SentenceTransformer("BAAI/bge-large-en-v1.5")

def cosine(u, v):
    """Return the cosine similarity of vectors *u* and *v*.

    Computed as ``dot(u, v) / (norm(u) * norm(v))``.

    If either vector has zero norm the similarity is undefined; ``0.0`` is
    returned in that case instead of dividing by zero (which would produce
    ``nan`` together with a NumPy runtime warning).
    """
    denom = norm(u) * norm(v)
    if denom == 0:
        return 0.0
    return dot(u, v) / denom

# Load the tag table, keeping only rows with a non-empty statement type.
# newline="" is what the csv module asks for so that quoted fields containing
# line breaks are parsed correctly. The encoding is made explicit so the result
# does not depend on the platform default (assumes the CSV is UTF-8 — adjust
# if it was exported with a different encoding).
with open(INPUT_CSV, "r", newline="", encoding="utf-8") as f:
    reader = csv.DictReader(f)
    rows = [row for row in reader if row["statement_type"]]

# Group rows by statement type: positives and negatives for an anchor are
# always drawn from the anchor's own statement type.
tag_rows_by_statement = defaultdict(list)
for row in rows:
    tag_rows_by_statement[row["statement_type"]].append(row)

triplets = []
triplet_attempts = 0

# description text -> embedding vector. The same description is drawn many
# times (as anchor, positive or negative), so each one is encoded only once.
_embedding_cache = {}


def _embed(texts):
    """Return the embeddings of *texts*, encoding each distinct text at most once."""
    missing = [t for t in dict.fromkeys(texts) if t not in _embedding_cache]
    if missing:
        _embedding_cache.update(zip(missing, model.encode(missing)))
    return [_embedding_cache[t] for t in texts]


for statement_type, tag_rows in tag_rows_by_statement.items():
    debit_tags = [r for r in tag_rows if r["balance"] == "debit"]
    credit_tags = [r for r in tag_rows if r["balance"] == "credit"]
    instant_tags = [r for r in tag_rows if r["period_type"] == "instant"]
    duration_tags = [r for r in tag_rows if r["period_type"] == "duration"]

    for anchor_row in tag_rows:
        if len(triplets) >= MAX_TRIPLETS:
            break

        a_text = anchor_row["description"]
        anchor_tag = anchor_row["tag"]
        anchor_balance = anchor_row["balance"]
        anchor_period = anchor_row["period_type"]

        # Positive: any other tag from the same statement type.
        positives = [r for r in tag_rows if r["tag"] != anchor_tag]
        if not positives:
            continue
        positive_row = random.choice(positives)
        positive_tag = positive_row["tag"]

        # Negative pool: rows with the opposite balance and/or the opposite
        # period type. Work on a COPY: extending the per-statement lists in
        # place would permanently pollute them for every later anchor.
        # A row that differs in both balance and period appears twice and is
        # therefore twice as likely to be drawn.
        if anchor_balance == "debit":
            candidates = list(credit_tags)
        elif anchor_balance == "credit":
            candidates = list(debit_tags)
        else:
            candidates = []

        if anchor_period == "instant":
            candidates.extend(duration_tags)
        elif anchor_period == "duration":
            candidates.extend(instant_tags)

        # Drop the anchor itself and the chosen positive: a negative equal to
        # the positive can never satisfy the margin test, so it would only
        # waste an attempt.
        candidates = [
            r for r in candidates
            if r["tag"] != anchor_tag and r["tag"] != positive_tag
        ]
        if not candidates:
            continue

        negative_row = random.choice(candidates)

        p_text = positive_row["description"]
        n_text = negative_row["description"]

        a_vec, p_vec, n_vec = _embed([a_text, p_text, n_text])
        sim_pos = cosine(a_vec, p_vec)
        sim_neg = cosine(a_vec, n_vec)
        triplet_attempts += 1

        # Keep only triplets the model already separates by at least MIN_MARGIN.
        if sim_pos > sim_neg + MIN_MARGIN:
            triplets.append({
                "anchor": a_text,
                "positive": p_text,
                "negative": n_text,
                "positive_tag": positive_tag,
                "negative_tag": negative_row["tag"],
            })

            # Report progress only when the count has just reached a multiple
            # of 100; checking outside this branch would print on every
            # rejected attempt while the count sits at 0, 100, 200, ...
            if len(triplets) % 100 == 0:
                print(f"Triplets saved: {len(triplets)} / {MAX_TRIPLETS} "
                      f"(attempts: {triplet_attempts})")

    if len(triplets) >= MAX_TRIPLETS:
        break

# One JSON object per line (JSON Lines).
with open(OUTPUT_JSONL, "w", encoding="utf-8") as f:
    for triplet in triplets:
        f.write(json.dumps(triplet) + "\n")

print(f"✅ Saved {len(triplets)} filtered triplets to {OUTPUT_JSONL}")
print(f"🧠 Total triplet candidates checked: {triplet_attempts}")
