# In [None]:
"""
MIXED ANCHOR GENERATION (Randomized at output time)

For each emitted triplet, the anchor text is chosen at random:
 - EITHER just the description
 - OR "statement_type.lower() description"

No duplication ahead of time, no altered anchor tags.
Balance/period logic for negatives is preserved.
Subcategory paths are left out of the anchor text only; similarity scoring and
the positive/negative texts use "subcategory_path description".
"""

import csv
import json
import random
from collections import defaultdict
from sentence_transformers import SentenceTransformer
import torch
from numpy import dot
from numpy.linalg import norm
from difflib import SequenceMatcher


# Use the "mps" backend when PyTorch reports it as available, otherwise the CPU.
# NOTE(review): CUDA is never selected here — confirm that is intended.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# ========== CONFIG ==========
# Source CSV; the script reads the columns statement_type, tag, balance,
# period_type, description and subcategory_path from it.
INPUT_CSV = "data/us_gaap_2025_verified_subcategory_path.csv"
# Destination of the final triplets, one JSON object per line.
OUTPUT_JSONL = "data/us_gaap_triplet_training_data_filtered.jsonl"

# Generation stops once this many triplets have been collected.
MAX_TRIPLETS = 200000
# A triplet is kept only when sim(anchor, positive) >= sim(anchor, negative) + MIN_MARGIN.
MIN_MARGIN = 0.05

# Cosine-similarity level checked against each positive candidate.
POSITIVE_SIM_THRESHOLD = 0.75
# Upper bound on the number of ranked positive candidates tried per anchor.
POSITIVE_ATTEMPTS_PER_ANCHOR = 30
# Upper bound on the number of ranked negative candidates tried per positive.
NEGATIVE_ATTEMPTS_PER_POSITIVE = 30
# Probability that negatives for a positive are drawn from other statement types.
NEGATIVE_CROSS_STATEMENT_RATIO = 0.5
# Maximum number of triplets accepted for one (anchor, positive) pair.
MAX_NEG_MATCHES_PER_POS = 1

# A partial file is written whenever the triplet count reaches a multiple of this.
PARTIAL_SAVE_INTERVAL = 10000
# SequenceMatcher ratio above which a positive is skipped as a near-duplicate
# of the anchor text.
NEAR_DUPLICATE_ANCHOR_POS_THRESHOLD = 0.90

# Embedding model used for every similarity computation in this script.
model = SentenceTransformer("BAAI/bge-large-en-v1.5", device=device)


def cosine(u, v):
    """Return the cosine similarity between vectors *u* and *v*.

    Args:
        u: First vector (any array-like accepted by numpy).
        v: Second vector, same length as *u*.

    Returns:
        dot(u, v) / (|u| * |v|), or 0.0 when either vector has zero length.
    """
    denominator = norm(u) * norm(v)
    if denominator == 0:
        # A zero-length vector has no direction. Dividing would yield nan
        # (plus a RuntimeWarning), which breaks the similarity-based sorting
        # done by callers, so report "no similarity" instead.
        return 0.0
    return dot(u, v) / denominator


def text_too_similar(txt_a, txt_b, threshold=0.9):
    """Report whether two strings are near-duplicates, ignoring case.

    Args:
        txt_a: First string.
        txt_b: Second string.
        threshold: Similarity ratio that must be strictly exceeded.

    Returns:
        True when the difflib similarity ratio of the lower-cased strings is
        greater than *threshold*, otherwise False.
    """
    matcher = SequenceMatcher(isjunk=None, a=txt_a.lower(), b=txt_b.lower())
    return matcher.ratio() > threshold

def get_full_description(row):
    """Build the text used for a row: its subcategory path, a space, its description.

    Args:
        row: Mapping with 'subcategory_path' and 'description' entries.

    Returns:
        The two values joined by a single space, with leading and trailing
        whitespace removed.
    """
    path_part = row['subcategory_path']
    description_part = row['description']
    combined = "{} {}".format(path_part, description_part)
    return combined.strip()

# Memo of text -> embedding so each distinct string is encoded only once.
vec_cache = {}
def get_vec(text):
    """Return the embedding of *text*, encoding it with `model` on first use.

    Args:
        text: String to embed; also used as the cache key.

    Returns:
        Whatever `model.encode(text)` returned the first time this text was seen.
    """
    try:
        return vec_cache[text]
    except KeyError:
        embedding = model.encode(text)
        vec_cache[text] = embedding
        return embedding

# ==============================
print("Loading CSV...")
# newline="" hands line-ending handling to the csv module, so quoted fields
# that contain embedded newlines are parsed correctly. The explicit encoding
# keeps the result independent of the platform default.
with open(INPUT_CSV, "r", newline="", encoding="utf-8") as f:
    reader = csv.DictReader(f)
    # Rows without a statement_type cannot be grouped below, so drop them here.
    rows = [r for r in reader if r.get("statement_type")]

all_rows = rows  # For cross-statement negative sampling

# Group rows by statement type; anchors and positives come from the same group.
by_statement = defaultdict(list)
for r in rows:
    by_statement[r["statement_type"]].append(r)

triplets = []          # accepted triplets, in generation order
triplet_count = 0      # number of accepted triplets
triplet_attempts = 0   # number of (anchor, positive, negative) margin checks made
used_combos = set()  # (anchor_tag, pos_tag, neg_tag)


def partial_save():
    """Dump every triplet collected so far to the partial JSONL file.

    The file is rewritten from scratch on each call (one JSON object per
    line), and a short progress message is printed afterwards.
    """
    partial_file = "data/us_gaap_triplet_training_data_partial.jsonl"
    with open(partial_file, "w") as pf:
        pf.writelines(json.dumps(item) + "\n" for item in triplets)
    print(f"[PARTIAL SAVE] Wrote {len(triplets)} triplets to {partial_file}")

# ========== GENERATE TRIPLETS ==========

def _differs_in_balance_or_period(row, balance, period):
    """Return True when *row* has a known balance or period_type unlike the anchor's.

    A field only counts as different when both the anchor's value and the
    row's value are non-empty, so rows with missing metadata never qualify.
    """
    row_balance = row.get("balance")
    row_period = row.get("period_type")
    return bool(
        (balance and row_balance and row_balance != balance)
        or (period and row_period and row_period != period)
    )


for statement_type, tag_rows in by_statement.items():
    # Visit the anchors of each statement type in random order.
    random.shuffle(tag_rows)

    for anchor_row in tag_rows:
        if triplet_count >= MAX_TRIPLETS:
            break

        anchor_tag = anchor_row["tag"]
        anchor_balance = anchor_row.get("balance", "")
        anchor_period = anchor_row.get("period_type", "")
        anchor_desc = anchor_row["description"]
        anchor_full_desc = get_full_description(anchor_row)
        anchor_vec = get_vec(anchor_full_desc)

        # Positive candidates: same statement_type, balance and period as the
        # anchor, excluding the anchor's own tag.
        positives = [
            r for r in tag_rows
            if r.get("balance", "") == anchor_balance
            and r.get("period_type", "") == anchor_period
            and r["tag"] != anchor_tag
        ]
        if not positives:
            continue

        # Record each candidate at most once, and only when it is similar
        # enough to the anchor to serve as a positive.
        positive_sims = []
        for candidate in positives:
            candidate_sim = cosine(get_vec(get_full_description(candidate)), anchor_vec)
            if candidate_sim >= POSITIVE_SIM_THRESHOLD:
                positive_sims.append((candidate, candidate_sim))

        # Most similar first; try at most POSITIVE_ATTEMPTS_PER_ANCHOR of them.
        positive_sims.sort(key=lambda pair: pair[1], reverse=True)
        positive_candidates = positive_sims[:POSITIVE_ATTEMPTS_PER_ANCHOR]

        # The scored negative pools depend only on the anchor, so each one
        # (cross-statement / same-statement) is built lazily and at most once
        # per anchor instead of once per positive.
        ranked_negatives = {}

        for p_row, sim_pos in positive_candidates:
            if triplet_count >= MAX_TRIPLETS:
                break

            p_full_desc = get_full_description(p_row)
            if text_too_similar(anchor_full_desc, p_full_desc, NEAR_DUPLICATE_ANCHOR_POS_THRESHOLD):
                continue

            use_cross_statement = random.random() < NEGATIVE_CROSS_STATEMENT_RATIO

            if use_cross_statement not in ranked_negatives:
                if use_cross_statement:
                    pool = [r for r in all_rows if r.get("statement_type") != statement_type]
                else:
                    pool = tag_rows
                scored = [
                    (r, cosine(get_vec(get_full_description(r)), anchor_vec))
                    for r in pool
                    if r["tag"] != anchor_tag
                    and _differs_in_balance_or_period(r, anchor_balance, anchor_period)
                ]
                # Ascending similarity: the least similar rows are tried first.
                scored.sort(key=lambda pair: pair[1])
                ranked_negatives[use_cross_statement] = scored

            # Never use the positive's own tag as its negative.
            negative_candidates = [
                pair for pair in ranked_negatives[use_cross_statement]
                if pair[0]["tag"] != p_row["tag"]
            ][:NEGATIVE_ATTEMPTS_PER_POSITIVE]
            if not negative_candidates:
                continue

            valid_neg_found = 0
            for n_row, sim_neg in negative_candidates:
                if triplet_count >= MAX_TRIPLETS:
                    break

                combo_key = (anchor_tag, p_row["tag"], n_row["tag"])
                if combo_key in used_combos:
                    continue

                triplet_attempts += 1
                if sim_pos < sim_neg + MIN_MARGIN:
                    continue

                # Randomly prefix the emitted anchor text with the statement
                # type. Only the output text changes; similarities above were
                # computed from the anchor's full description.
                if random.random() < 0.5:
                    anchor_text = f"{statement_type.lower()} {anchor_desc}"
                else:
                    anchor_text = anchor_desc

                triplets.append({
                    "anchor": anchor_text,
                    "positive": p_full_desc,
                    "negative": get_full_description(n_row),
                    "positive_tag": p_row["tag"],
                    "negative_tag": n_row["tag"]
                })
                used_combos.add(combo_key)
                triplet_count += 1
                valid_neg_found += 1

                if triplet_count % PARTIAL_SAVE_INTERVAL == 0:
                    partial_save()

                if valid_neg_found >= MAX_NEG_MATCHES_PER_POS:
                    break

# ========== SAVE FINAL ==========
# Write one JSON object per line, then report how many triplets were saved and
# how many candidate triplets were checked in total.
with open(OUTPUT_JSONL, "w") as out_f:
    for row in triplets:
        print(json.dumps(row), file=out_f)

print(f"✅ Saved {len(triplets)} triplets to {OUTPUT_JSONL}")
print(f"🧠 Total triplet candidates checked: {triplet_attempts}")
