In [None]:
import pandas as pd
from db import DB
from tqdm import tqdm
from models.pytorch.us_gaap_alignment.inference import find_closest_match
from models.pytorch.us_gaap_alignment import UsGaapAlignmentModel
from utils.pytorch import seed_everything, get_device

# Pick the compute device via the project's get_device helper and report it.
device = get_device()
print(f"Using device: {device}")

# Restore the US-GAAP alignment model from the checkpoint directory onto `device`.
model = UsGaapAlignmentModel.from_pretrained("data/pretrained", device=device)

# Project database handle used for both the read below and the upserts later.
db = DB()

# Anti-join: LEFT JOIN the mapping table and keep only rows where the joined
# m.ofss_category_id is NULL, i.e. concepts with no OFSS category mapping.
# Balance and period type are LEFT JOINed so concepts lacking them are still
# returned (with NULLs); concept type is an inner JOIN, so concepts without a
# concept type are not returned at all.
query = """
SELECT
    c.id,
    c.name,
    ct.concept_type,
    bt.balance,
    pt.period_type
FROM us_gaap_concept c
JOIN us_gaap_concept_type ct ON ct.id = c.concept_type_id
LEFT JOIN us_gaap_balance_type bt ON bt.id = c.balance_type_id
LEFT JOIN us_gaap_period_type pt ON pt.id = c.period_type_id
LEFT JOIN us_gaap_concept_ofss_category m ON m.us_gaap_concept_id = c.id
WHERE m.ofss_category_id IS NULL
"""

df = db.get(
    query,
    ["id", "name", "concept_type", "balance", "period_type"],
)
print("Found {} unmapped concepts.".format(len(df)))

# Reference dataset passed to find_closest_match for every concept (judging by
# the filename it holds concepts with name variations and embeddings).
dataset_path = "data/us_gaap_concepts_with_variations_and_embeddings.jsonl"


def _upsert_links(database, table_name, target_field, concept_id, target_ids):
    """Upsert one link row per id in *target_ids* and return how many were written.

    Each row links *concept_id* (column ``us_gaap_concept_id``) to one target id
    (column *target_field*) in *table_name*, keyed on that pair, with
    ``is_manually_mapped`` set to 0 (presumably marking the link as
    model-generated rather than hand-curated).
    """
    written = 0
    for target_id in target_ids:
        database.upsert_entity(
            table_name=table_name,
            field_dict={
                "us_gaap_concept_id": concept_id,
                target_field: target_id,
                "is_manually_mapped": 0,
            },
            unique_fields=["us_gaap_concept_id", target_field],
        )
        written += 1
    return written


# Concepts that finish this run still lacking any OFSS category link.
still_unmapped = 0

for row in tqdm(df.itertuples(index=False), total=len(df), desc="Mapping Concepts"):
    result = find_closest_match(
        row.name,
        model=model,
        concept_type=row.concept_type,
        balance_type=row.balance,
        period_type=row.period_type,
        dataset_path=dataset_path,
        top_k=1,
        device=device,
    )

    # No candidate returned: leave the concept unmapped.
    if not result:
        still_unmapped += 1
        continue

    closest = result[0]

    # Both id lists are treated as optional. An unguarded lookup would raise
    # KeyError mid-batch, after earlier concepts were already upserted.
    ofss_written = _upsert_links(
        db,
        "us_gaap_concept_ofss_category",
        "ofss_category_id",
        row.id,
        closest.get("ofss_category_ids", ()),
    )
    _upsert_links(
        db,
        "us_gaap_concept_statement_type",
        "us_gaap_statement_type_id",
        row.id,
        closest.get("statement_type_ids", ()),
    )

    if ofss_written == 0:
        still_unmapped += 1

if still_unmapped:
    print(f"{still_unmapped} of {len(df)} concepts remain without an OFSS category mapping.")
print("Inference and upsert complete.")
