In [None]:
"""
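RANDOM TRIPLET GENERATION WITH BALANCE/PERIOD CHECK FOR NEGATIVES
- Anchors, positives and negatives are all drawn from the same statement_type group.
- For each anchor:
  1) Randomly sample up to POSITIVE_ATTEMPTS_PER_ANCHOR candidate positives.
     - Skip a positive that is a near-duplicate of the anchor
       (case-insensitive SequenceMatcher ratio > NEAR_DUPLICATE_ANCHOR_POS_THRESHOLD).
  2) For each (anchor, positive), build a candidate negative list:
     - If the anchor has a balance => rows with a different (non-empty) balance
     - If that yields nothing and the anchor has a period => rows with a
       different (non-empty) period_type
     - Then randomly sample up to NEGATIVE_ATTEMPTS_PER_POSITIVE negatives from that list
  3) A triplet is valid if sim_pos >= sim_neg + MIN_MARGIN
     (cosine similarities of the positive/negative to the anchor embedding)
     - Keep up to MAX_NEG_MATCHES_PER_POS (2) valid negatives per anchor–positive
- De-duplicate (anchor, positive, negative) tag combos via used_combos
- Stop once MAX_TRIPLETS triplets have been accepted
- Write partial results every PARTIAL_SAVE_INTERVAL accepted triplets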
"""

import csv
import json
import random
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from numpy import dot
from numpy.linalg import norm
from difflib import SequenceMatcher

######################## CONFIG ########################
INPUT_CSV = "data/us_gaap_2025_verified_subcategory_path.csv"
OUTPUT_JSONL = "data/us_gaap_triplet_training_data_filtered.jsonl"

MAX_TRIPLETS = 100000  # stop once this many triplets have been accepted
MIN_MARGIN = 0.05  # a triplet needs sim_pos >= sim_neg + MIN_MARGIN

POSITIVE_ATTEMPTS_PER_ANCHOR = 20  # max candidate positives sampled per anchor
NEGATIVE_ATTEMPTS_PER_POSITIVE = 20  # max candidate negatives per anchor–positive
MAX_NEG_MATCHES_PER_POS = 2  # up to 2 negatives per anchor–positive

PARTIAL_SAVE_INTERVAL = 100  # write a partial checkpoint every N accepted triplets
NEAR_DUPLICATE_ANCHOR_POS_THRESHOLD = 0.90  # difflib.SequenceMatcher ratio

# Seed for the `random` module, which drives all anchor/positive/negative
# sampling. None (the default) seeds from OS entropy, so every run samples
# differently; set an int to make the generated dataset reproducible.
RANDOM_SEED = None
random.seed(RANDOM_SEED)

model = SentenceTransformer("BAAI/bge-large-en-v1.5")

def cosine(u, v):
    """Return the cosine similarity between vectors ``u`` and ``v``.

    If either vector has zero norm, the similarity is undefined. In that case
    0.0 is returned instead of the NaN (plus RuntimeWarning) that the plain
    ratio would produce.
    """
    denom = norm(u) * norm(v)
    if denom == 0:
        return 0.0
    return dot(u, v) / denom

def text_too_similar(txt_a, txt_b, threshold=0.9):
    """Tell whether two strings are near-duplicates, ignoring letter case.

    Similarity is difflib.SequenceMatcher's ratio (0.0-1.0) of the lower-cased
    texts; the pair counts as "too similar" only when that ratio is strictly
    above ``threshold``.
    """
    return threshold < SequenceMatcher(a=txt_a.lower(), b=txt_b.lower()).ratio()

############################################

# Load tag rows; rows without a statement_type cannot be grouped and are dropped.
# newline="" is what the csv module requires so quoted fields containing
# newlines are parsed correctly. "utf-8-sig" decodes UTF-8 independent of the
# platform locale and also strips a leading byte-order mark, which would
# otherwise be glued onto the first header name.
with open(INPUT_CSV, "r", newline="", encoding="utf-8-sig") as f:
    reader = csv.DictReader(f)
    rows = [r for r in reader if r.get("statement_type")]

# Group by statement_type => anchors, positives and negatives all come from
# the same statement group.
by_statement = defaultdict(list)
for r in rows:
    by_statement[r["statement_type"]].append(r)

# Generation state shared by the main loop and partial_save().
triplets = []  # accepted triplet records, in acceptance order
triplet_count = 0  # number of accepted triplets (== len(triplets))
triplet_attempts = 0  # negatives actually scored against the margin
used_combos = set()  # (anchor_tag, pos_tag, neg_tag)

def partial_save():
    """Checkpoint the module-level ``triplets`` list to the partial JSONL file.

    After each call the file holds every element of ``triplets``, one JSON
    object per line. That is the same on-disk state as rewriting the whole
    file, but only rows added since the previous call are appended, so total
    checkpoint I/O stays linear in the number of triplets.

    Assumes ``triplets`` is only ever appended to between calls. If it is
    rebound to a new list, or has shrunk, the file is rewritten from scratch.
    """
    partial_file = "data/us_gaap_triplet_training_data_partial.jsonl"
    # (list object saved last time, how many of its rows are already on disk)
    saved_list, saved_count = getattr(partial_save, "_progress", (None, 0))
    if saved_list is not triplets or saved_count > len(triplets):
        saved_count = 0  # first call, or list replaced/shrunk: start over
    mode = "a" if saved_count else "w"
    with open(partial_file, mode) as pf:
        pf.writelines(json.dumps(item) + "\n" for item in triplets[saved_count:])
    partial_save._progress = (triplets, len(triplets))
    print(f"[PARTIAL SAVE] Wrote {len(triplets)} triplets to {partial_file}")

# Embedding cache: description text -> vector returned by model.encode.
# The same descriptions recur as anchors, positives and negatives many times.
# Encoding each distinct text once and re-using it gives identical vectors
# without the repeated model forward passes.
_embedding_cache = {}


def _embed(text):
    """Return ``model.encode(text)``, computing it at most once per distinct text."""
    vec = _embedding_cache.get(text)
    if vec is None:
        vec = model.encode(text)
        _embedding_cache[text] = vec
    return vec


def _differing_pool(rows, exclude_tag, field, anchor_value):
    """Return rows whose ``field`` is non-empty and differs from ``anchor_value``.

    Original row order is kept and rows tagged ``exclude_tag`` are skipped.
    An empty ``anchor_value`` gives nothing to contrast against, so [] is returned.
    """
    if not anchor_value:
        return []
    return [
        r for r in rows
        if r["tag"] != exclude_tag
        and r.get(field, "")
        and r[field] != anchor_value
    ]


for statement_type, tag_rows in by_statement.items():
    # Once the cap is hit, remaining groups would only break on their first anchor.
    if triplet_count >= MAX_TRIPLETS:
        break

    # Shuffle anchor rows for randomness
    random.shuffle(tag_rows)

    for anchor_row in tag_rows:
        if triplet_count >= MAX_TRIPLETS:
            break

        anchor_tag = anchor_row["tag"]
        anchor_text = anchor_row["description"]
        anchor_balance = anchor_row.get("balance", "")
        anchor_period = anchor_row.get("period_type", "")

        # 1) Candidate positives: every other tag in the same statement group.
        # NOTE(review): positives are not filtered on balance/period, so a
        # positive may differ from the anchor on the very attribute used to
        # choose negatives — confirm this is intended.
        positives = [r for r in tag_rows if r["tag"] != anchor_tag]
        if not positives:
            continue
        anchor_vec = _embed(anchor_text)

        # We'll pick random positives (up to N attempts)
        positive_candidates = random.sample(
            positives, min(POSITIVE_ATTEMPTS_PER_ANCHOR, len(positives))
        )

        # 2) The negative pools depend only on the anchor, so build them once
        # per anchor rather than once per positive. Rows with a different
        # balance are preferred; rows with a different period are the fallback.
        balance_pool = _differing_pool(tag_rows, anchor_tag, "balance", anchor_balance)
        period_pool = _differing_pool(tag_rows, anchor_tag, "period_type", anchor_period)

        for p_row in positive_candidates:
            if triplet_count >= MAX_TRIPLETS:
                break

            p_tag = p_row["tag"]
            p_text = p_row["description"]
            if text_too_similar(anchor_text, p_text, NEAR_DUPLICATE_ANCHOR_POS_THRESHOLD):
                # near-duplicate anchor vs positive
                continue

            sim_pos = cosine(anchor_vec, _embed(p_text))

            # The positive itself may never serve as the negative.
            negatives = [r for r in balance_pool if r["tag"] != p_tag]
            if not negatives:
                negatives = [r for r in period_pool if r["tag"] != p_tag]
            if not negatives:
                continue

            # Randomly pick negatives
            negative_candidates = random.sample(
                negatives, min(NEGATIVE_ATTEMPTS_PER_POSITIVE, len(negatives))
            )

            valid_neg_found = 0
            for n_row in negative_candidates:
                if triplet_count >= MAX_TRIPLETS:
                    break

                combo_key = (anchor_tag, p_tag, n_row["tag"])
                if combo_key in used_combos:
                    continue

                n_text = n_row["description"]
                sim_neg = cosine(anchor_vec, _embed(n_text))
                triplet_attempts += 1

                # 3) Margin check: the negative must be at least MIN_MARGIN less
                # similar to the anchor than the positive is.
                if sim_pos >= sim_neg + MIN_MARGIN:
                    triplets.append({
                        "anchor": anchor_text,
                        "positive": p_text,
                        "negative": n_text,
                        "positive_tag": p_tag,
                        "negative_tag": n_row["tag"]
                    })
                    used_combos.add(combo_key)
                    triplet_count += 1
                    valid_neg_found += 1

                    # partial save
                    if triplet_count % PARTIAL_SAVE_INTERVAL == 0:
                        partial_save()

                    # Cap the number of negatives kept per anchor–positive pair.
                    if valid_neg_found >= MAX_NEG_MATCHES_PER_POS:
                        break

# Final save: persist every accepted triplet as JSON Lines (one object per
# line), then report how many were kept and how many negatives were scored.
with open(OUTPUT_JSONL, "w") as out_f:
    out_f.writelines(f"{json.dumps(record)}\n" for record in triplets)

print(f"✅ Saved {len(triplets)} triplets to {OUTPUT_JSONL}")
print(f"🧠 Total triplet candidates checked: {triplet_attempts}")
