In [None]:

import csv
import itertools
import json
import random
from collections import defaultdict

# Source table: one CSV row per US-GAAP tag (columns read by generate_triplet_dataset).
INPUT_CSV = "data/us_gaap_2025_verified_subcategory_path.csv"
# Destination: one JSON-encoded triplet per line.
OUTPUT_JSONL = "data/us_gaap_triplet_training_data.jsonl"
# Upper bound on how many triplets get written.
MAX_TRIPLETS = 5_000

def _iter_statement_triplets(tag_rows, rng):
    """Yield one triplet dict per usable anchor row of a single statement type.

    An anchor is skipped when it has no positive (no other tag in the same
    statement type) or no negative (no tag with the opposite balance or the
    opposite period type, after excluding the anchor and positive tags).
    """
    # Partition once per statement type. These pools are only read below and
    # never extended, so every anchor draws from the same, unmodified pools.
    debit_tags = [r for r in tag_rows if r["balance"] == "debit"]
    credit_tags = [r for r in tag_rows if r["balance"] == "credit"]
    instant_tags = [r for r in tag_rows if r["period_type"] == "instant"]
    duration_tags = [r for r in tag_rows if r["period_type"] == "duration"]
    opposite_balance = {"debit": credit_tags, "credit": debit_tags}
    opposite_period = {"instant": duration_tags, "duration": instant_tags}

    for anchor_row in tag_rows:
        anchor_tag = anchor_row["tag"]

        # Positive: same statement type, different tag.
        positives = [r for r in tag_rows if r["tag"] != anchor_tag]
        if not positives:
            continue
        positive_row = rng.choice(positives)

        # Negative: opposite balance or opposite period type. `+` builds a new
        # list; `+=` on an alias would grow the shared pool for every later
        # anchor. Rows in both pools appear twice and so are weighted double.
        # The positive's tag is excluded so positive and negative never coincide.
        excluded = {anchor_tag, positive_row["tag"]}
        pool = (
            opposite_balance.get(anchor_row["balance"], [])
            + opposite_period.get(anchor_row["period_type"], [])
        )
        candidates = [r for r in pool if r["tag"] not in excluded]
        if not candidates:
            continue
        negative_row = rng.choice(candidates)

        yield {
            "anchor": anchor_row["description"],
            "positive": positive_row["description"],
            "negative": negative_row["description"],
            "positive_tag": positive_row["tag"],
            "negative_tag": negative_row["tag"],
        }


def generate_triplet_dataset(*, input_csv=None, output_jsonl=None,
                             max_triplets=None, seed=None):
    """Build (anchor, positive, negative) description triplets and write them as JSONL.

    Reads the CSV columns ``statement_type``, ``tag``, ``description``,
    ``balance`` and ``period_type``. Rows with an empty ``statement_type`` are
    dropped. Within each statement type, every row acts as an anchor:

    * the positive is another tag from the same statement type;
    * the negative is a tag from the same statement type with the opposite
      ``balance`` (debit/credit) or the opposite ``period_type``
      (instant/duration), never equal to the anchor or positive tag.

    Each output line is a JSON object with keys ``anchor``, ``positive``,
    ``negative``, ``positive_tag`` and ``negative_tag``.

    Args:
        input_csv: Path of the source CSV; defaults to ``INPUT_CSV``.
        output_jsonl: Path of the JSONL file to (over)write; defaults to
            ``OUTPUT_JSONL``.
        max_triplets: Maximum number of triplets to write; defaults to
            ``MAX_TRIPLETS``. Values <= 0 produce an empty file.
        seed: Optional seed for a private ``random.Random``. When ``None``,
            the global ``random`` module state is used.

    Raises:
        KeyError: if the CSV header lacks one of the columns listed above.
    """
    if input_csv is None:
        input_csv = INPUT_CSV
    if output_jsonl is None:
        output_jsonl = OUTPUT_JSONL
    if max_triplets is None:
        max_triplets = MAX_TRIPLETS
    # Without a seed, keep drawing from the module-level generator so callers
    # that seed `random` globally still get reproducible output.
    rng = random if seed is None else random.Random(seed)

    # newline="" is what the csv module expects for correct quoting/newlines.
    with open(input_csv, "r", newline="", encoding="utf-8") as f:
        rows = [row for row in csv.DictReader(f) if row["statement_type"]]

    # Group rows by statement type.
    tag_rows_by_statement = defaultdict(list)
    for row in rows:
        tag_rows_by_statement[row["statement_type"]].append(row)

    all_triplets = itertools.chain.from_iterable(
        _iter_statement_triplets(tag_rows, rng)
        for tag_rows in tag_rows_by_statement.values()
    )
    # islice stops pulling from the generator (and drawing random numbers)
    # as soon as the cap is reached.
    triplets = list(itertools.islice(all_triplets, max(max_triplets, 0)))

    with open(output_jsonl, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(t) + "\n" for t in triplets)

    print(f"Generated {len(triplets)} triplets to {output_jsonl}")

if __name__ == "__main__":
    # Entry point: build the dataset using the module-level default settings.
    generate_triplet_dataset()
