In [None]:
import os

import pandas as pd

from db import DB
from utils import generate_us_gaap_description

# Module-level handle to the project-local DB wrapper (imported from `db`);
# build_dataset() runs its queries through `db.get`.
db = DB()

# SQL for the training split: one row per (description variation text,
# balance type, period type), with the tag's category ids concatenated into a
# comma-separated `labels` string. Missing type ids are replaced by 0.
_TRAIN_QUERY = """
        SELECT
            v.text AS input_text,
            COALESCE(t.balance_type_id, 0) AS balance_type_id,
            COALESCE(t.period_type_id, 0) AS period_type_id,
            GROUP_CONCAT(DISTINCT m.ofss_category_id ORDER BY m.ofss_category_id) AS labels
        FROM us_gaap_tag_description_variation v
        JOIN us_gaap_tag t ON t.id = v.us_gaap_tag_id
        JOIN us_gaap_tag_ofss_category m ON m.us_gaap_tag_id = t.id
        GROUP BY v.text, t.balance_type_id, t.period_type_id
    """

# SQL for the validation split: same columns, but the input text is the raw
# tag name (build_dataset post-processes it for the "val" split).
_VAL_QUERY = """
        SELECT
            t.name AS input_text,
            COALESCE(t.balance_type_id, 0) AS balance_type_id,
            COALESCE(t.period_type_id, 0) AS period_type_id,
            GROUP_CONCAT(DISTINCT m.ofss_category_id ORDER BY m.ofss_category_id) AS labels
        FROM us_gaap_tag t
        JOIN us_gaap_tag_ofss_category m ON m.us_gaap_tag_id = t.id
        GROUP BY t.name, t.balance_type_id, t.period_type_id
    """

# Split name -> SQL text; insertion order is the order the splits are built in.
queries = {
    "train": _TRAIN_QUERY,
    "val": _VAL_QUERY,
}

def build_dataset(
    name: str,
    query: str,
    *,
    max_labels: int = 2,
    output_dir: str = "data",
) -> None:
    """Run *query* and write the resulting dataset split as JSON Lines.

    The result of ``db.get`` is post-processed and saved to
    ``<output_dir>/<name>.jsonl`` with one JSON record per line.

    Args:
        name: Split name; also the stem of the output file. For ``"val"``
            the ``input_text`` column is passed through
            ``generate_us_gaap_description`` before saving.
        query: SQL text handed to ``db.get``. The result is read as the
            columns ``input_text``, ``balance_type_id``, ``period_type_id``
            and ``labels`` (a comma-separated string of integer ids).
        max_labels: Keep at most this many labels per row. Defaults to 2,
            the limit that was previously hard-coded.
        output_dir: Directory the file is written to; created if missing.

    Raises:
        ValueError: If a token in ``labels`` is not an integer literal.
    """
    columns = ["input_text", "balance_type_id", "period_type_id", "labels"]
    # NOTE(review): db.get presumably returns a pandas DataFrame with these
    # column names -- the code below relies on .apply and .to_json.
    df = db.get(query, columns)

    if name == "val":
        # The validation query selects raw tag names; convert them with the
        # project helper before they are used as model input.
        df["input_text"] = df["input_text"].apply(generate_us_gaap_description)

    def parse_labels(raw):
        # A missing value (None/NaN) or an empty string means "no labels".
        # NaN is truthy, so a plain truthiness test would fall through to
        # .split() and crash.
        if pd.isna(raw) or not raw:
            return []
        # Parse every token first so a malformed id is still reported even
        # when it falls beyond the max_labels cut-off.
        return [int(token) for token in raw.split(",")][:max_labels]

    df["labels"] = df["labels"].apply(parse_labels)

    # to_json does not create parent directories; make sure the target exists.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f"{name}.jsonl")

    df.to_json(out_path, orient="records", lines=True)
    print(f"{name}.jsonl saved with {len(df)} rows.")

if __name__ == "__main__":
    # Build every configured split, in the order the queries are declared.
    for split_name in queries:
        build_dataset(split_name, queries[split_name])
