In [None]:
import os
import re
import csv
import xml.etree.ElementTree as ET
from collections import defaultdict

# https://www.fasb.org/page/detail?pageId=/projects/FASB-Taxonomies/2025-gaap-financial-reporting-taxonomy.html
# 2025 Taxonomy downloads: https://xbrl.fasb.org/us-gaap/2025/us-gaap-2025.zip

# === CONFIG ===
# Location of the unpacked 2025 taxonomy download and the inputs read from it.
TAXONOMY_DIR = "data/us-gaap-2025/"
ELTS_XSD = os.path.join(TAXONOMY_DIR, "elts", "us-gaap-2025.xsd")
STM_DIR = os.path.join(TAXONOMY_DIR, "stm")

# CSV produced by step 6 of this cell.
OUTPUT_PATH = "data/WITH_TAXONOMY_HIERARCHY_us_gaap_2025_with_all_statements_and_hierarchy.csv"

# Fourth "-"-separated token of a presentation file name -> statement label.
FILENAME_STATEMENT_MAP = dict(
    scf="Cash Flow Statement",
    soi="Income Statement",
    sfp="Balance Sheet",
    sheci="Equity Statement",
    soc="Comprehensive Income",
)

# Namespace-qualified (Clark notation) attribute keys used on schema elements.
XBRLI_NS = "http://www.xbrl.org/2003/instance"
BALANCE_KEY = "{" + XBRLI_NS + "}balance"
PERIOD_TYPE_KEY = "{" + XBRLI_NS + "}periodType"


def generate_description(tag_name):
    """Turn a CamelCase tag name into a lowercase, space-separated phrase.

    A space is inserted between a lowercase letter and the uppercase letter
    that follows it, and before the last capital of an uppercase run when that
    capital starts a new word (``"EBITDAMargin"`` -> ``"ebitda margin"``).
    Digits never create a boundary.
    """
    word_boundary = re.compile(r'(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])')
    spaced = word_boundary.sub(' ', tag_name)
    return spaced.lower()


# === STEP 1: TAG METADATA ===
# Record balance / period type for every named, non-abstract element declared
# in the schema. The "statements" set starts empty and is filled by step 2.
tag_metadata = {}
tree = ET.parse(ELTS_XSD)
root = tree.getroot()
for element in root.findall(".//{http://www.w3.org/2001/XMLSchema}element"):
    attrs = element.attrib
    name = attrs.get("name")
    if name and attrs.get("abstract") != "true":
        tag_metadata[name] = dict(
            balance=attrs.get(BALANCE_KEY, ""),
            period_type=attrs.get(PERIOD_TYPE_KEY, ""),
            statements=set(),
        )

# === STEP 2: STATEMENT TYPE MAPPING ===
# Every "-pre-" XML file in STM_DIR whose fourth "-"-separated name token is a
# known statement code tags all the elements it locates with that statement.
for file in os.listdir(STM_DIR):
    is_presentation_xml = file.endswith(".xml") and "-pre-" in file
    parts = file.split("-")
    if not is_presentation_xml or len(parts) < 4:
        continue
    inferred_statement = FILENAME_STATEMENT_MAP.get(parts[3])
    if not inferred_statement:
        continue

    root = ET.parse(os.path.join(STM_DIR, file)).getroot()
    for loc in root.findall(".//{http://www.xbrl.org/2003/linkbase}loc"):
        href = loc.attrib.get("{http://www.w3.org/1999/xlink}href", "")
        fragment = href.split("#")[-1]
        tag = fragment.replace("us-gaap_", "") if fragment.startswith("us-gaap_") else fragment
        entry = tag_metadata.get(tag)
        if entry is not None:
            entry["statements"].add(inferred_statement)

# === STEP 3: BUILD BASE OUTPUT ===
# One row per (tag, statement) pair. The statements of a tag live in a set,
# whose iteration order for strings changes between interpreter runs, so they
# are sorted here to make the row order (and the CSV) reproducible.
final_output = [
    {
        "tag": tag,
        "statement_type": stmt,
        "balance": meta["balance"],
        "period_type": meta["period_type"],
        "description": generate_description(tag),
        "subcategory_path": ""
    }
    for tag, meta in tag_metadata.items()
    for stmt in sorted(meta["statements"])
]

# === STEP 4: EXTRACT TAXONOMY HIERARCHY PATHS ===
def extract_tag_hierarchy_paths(pre_files):
    """Collect presentation-hierarchy paths for us-gaap tags.

    Each file is parsed as XML. Its ``loc`` elements whose href contains
    ``us-gaap_`` give a label -> tag map, and all of its ``presentationArc``
    elements are merged into one parent -> children graph. The graph is walked
    from every label that appears as an arc source but never as an arc target,
    and a path is recorded for every arc target that resolves to a us-gaap tag.

    Args:
        pre_files: iterable of paths to presentation linkbase XML files.

    Returns:
        ``defaultdict(set)`` mapping a tag name (``us-gaap_`` prefix removed)
        to the set of ``" > "``-joined paths leading to it. Ancestors in a path
        keep their ``us-gaap_`` prefix; the final element is the bare tag name.

    Files that are not well-formed XML are skipped with a warning. Arcs that
    would revisit a label already on the current path are ignored, so cyclic
    arc sets cannot cause unbounded recursion.
    """
    import warnings

    loc_xpath = ".//{http://www.xbrl.org/2003/linkbase}loc"
    arc_xpath = ".//{http://www.xbrl.org/2003/linkbase}presentationArc"
    xlink = "{http://www.w3.org/1999/xlink}"

    tag_paths = defaultdict(set)

    for path in pre_files:
        try:
            root = ET.parse(path).getroot()
        except ET.ParseError as exc:
            # Best effort: one broken file must not abort the whole scan,
            # but it should not disappear silently either.
            warnings.warn(f"Skipping unparsable presentation file {path}: {exc}")
            continue

        # Locator label -> href fragment (still carrying the "us-gaap_" prefix).
        loc_map = {}
        for loc in root.findall(loc_xpath):
            loc_id = loc.attrib.get(xlink + "label")
            href = loc.attrib.get(xlink + "href", "")
            if loc_id and "us-gaap_" in href:
                loc_map[loc_id] = href.split("#")[-1]

        arcs_by_from = defaultdict(list)
        for arc in root.findall(arc_xpath):
            arcs_by_from[arc.attrib.get(xlink + "from")].append(arc.attrib.get(xlink + "to"))

        all_to = {child for children in arcs_by_from.values() for child in children}
        root_ids = set(arcs_by_from) - all_to

        def walk(current_id, ancestors, active):
            """Record paths below ``current_id``; ``active`` holds the labels on the current path."""
            tag = loc_map.get(current_id)
            current_path = ancestors + [tag] if tag else ancestors
            for child_id in arcs_by_from.get(current_id, []):
                if child_id in active:
                    # Following this arc would loop back onto the current path.
                    continue
                child_tag = loc_map.get(child_id)
                if child_tag and child_tag.startswith("us-gaap_"):
                    clean_tag = child_tag.replace("us-gaap_", "")
                    tag_paths[clean_tag].add(" > ".join(current_path + [clean_tag]))
                active.add(child_id)
                walk(child_id, current_path, active)
                active.discard(child_id)

        for root_id in root_ids:
            walk(root_id, [], {root_id})

    return tag_paths


# Gather all presentation files (stm + dis + others): every "*-pre-*.xml"
# found anywhere below TAXONOMY_DIR.
pre_files = [
    os.path.join(dirpath, filename)
    for dirpath, _dirnames, filenames in os.walk(TAXONOMY_DIR)
    for filename in filenames
    if "-pre-" in filename and filename.endswith(".xml")
]

tag_hierarchy = extract_tag_hierarchy_paths(pre_files)

# === STEP 5: ATTACH HIERARCHY TO TAGS ===
ALL_STATEMENTS = {
    "Cash Flow Statement",
    "Income Statement",
    "Balance Sheet",
    "Equity Statement",
    "Comprehensive Income"
}

# Keep rows whose statement is recognised and attach one hierarchy path each.
# A tag may have several candidate paths stored in a set; taking "the first"
# element of a set is arbitrary and differs between runs, so the
# lexicographically smallest path is used to keep the output reproducible.
final_filtered = []
for row in final_output:
    if row["statement_type"] not in ALL_STATEMENTS:
        continue
    tag = row["tag"]
    candidate_paths = tag_hierarchy.get(tag)
    row["taxonomy_hierarchy"] = min(candidate_paths) if candidate_paths else ""
    final_filtered.append(row)

# === STEP 6: WRITE OUTPUT CSV ===
# os.makedirs("output", exist_ok=True)
# Header first, then one line per filtered row, in the column order below.
with open(OUTPUT_PATH, "w", newline="") as out_file:
    writer = csv.DictWriter(
        out_file,
        fieldnames=[
            "tag",
            "statement_type",
            "balance",
            "period_type",
            "description",
            "subcategory_path",
            "taxonomy_hierarchy",
        ],
    )
    writer.writeheader()
    for csv_row in final_filtered:
        writer.writerow(csv_row)


# Append OFSS IDs

In [None]:
import csv
import json
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util

# === CONFIG ===
MODEL_NAME = "BAAI/bge-large-en-v1.5"
INPUT_CSV = "data/WITH_TAXONOMY_HIERARCHY_us_gaap_2025_with_all_statements_and_hierarchy.csv"
OFSS_JSON = "../shared/open_financial_statement_schema.json"
OUTPUT_CSV = "data/with_ofss_ids.csv"
# Minimum cosine similarity a candidate needs before it is considered.
SIMILARITY_THRESHOLD = 0.7

# === DEVICE SETUP ===
# Preference order: MPS, then CUDA, then CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

model = SentenceTransformer(MODEL_NAME, device=device)

# === FLATTEN NESTED OFSS MAP ===
def flatten_ofss(d, parent_key=""):
    """Flatten a nested dict into ``{"outer/inner/leaf": value}``.

    Keys along the way are joined with ``/``; ``parent_key`` (when non-empty)
    prefixes every produced key. Any value that is not a dict is a leaf and is
    stored unchanged; empty dicts contribute nothing.
    """
    def iter_leaves(mapping, prefix):
        # Depth-first, in insertion order, yielding (flattened key, leaf value).
        for key, value in mapping.items():
            full_key = f"{prefix}/{key}" if prefix else key
            if isinstance(value, dict):
                yield from iter_leaves(value, full_key)
            else:
                yield full_key, value

    return dict(iter_leaves(d, parent_key))

def normalize_stmt_prefix(stmt):
    """Lowercase ``stmt`` and remove every occurrence of ``" statement"``."""
    lowered = stmt.lower()
    return "".join(lowered.split(" statement"))

# Enhanced validation rules
def is_valid_mapping(stmt, balance, tag, mapped_name):
    """Return False when a tag -> OFSS mapping breaks a known accounting rule.

    Args:
        stmt: statement label, e.g. ``"Balance Sheet"`` (case-insensitive).
        balance: balance type of the tag (``"debit"``, ``"credit"`` or empty);
            converted with ``str()`` before comparison.
        tag: tag name. Element names such as ``TreasuryStockValue`` carry no
            spaces, so rule phrases are matched against the tag with all
            whitespace ignored; an already spaced description works too.
        mapped_name: candidate OFSS name; matched as a lowercase substring.

    Returns:
        True if no rule for the given statement rejects the pairing.
    """
    stmt = stmt.lower()
    balance = str(balance).lower()
    mapped = mapped_name.lower()
    # Lowercasing "TreasuryStockValue" gives "treasurystockvalue", which can
    # never contain "treasury stock". Comparing space-free forms makes the
    # multi-word rules below effective for CamelCase and spaced input alike.
    tag_compact = "".join(tag.lower().split())

    def tag_has(phrase):
        """True when ``phrase`` occurs in the tag, ignoring spaces."""
        return phrase.replace(" ", "") in tag_compact

    def mapped_has_any(phrases):
        """True when any of ``phrases`` is a substring of the mapped name."""
        return any(phrase in mapped for phrase in phrases)

    def hits_blocked_pair(pairs):
        """True when some (tag phrase, mapped phrase) pair matches both sides."""
        return any(tag_has(tag_phrase) and mapped_phrase in mapped
                   for tag_phrase, mapped_phrase in pairs)

    # === Balance Sheet Validation ===
    if stmt == "balance sheet":
        # Prevent debit items mapping to liabilities/equity accounts
        if balance == "debit" and mapped_has_any(
            ["liabilities", "payable", "accrued", "common stock"]
        ):
            return False

        # Prevent credit items mapping to asset accounts
        if balance == "credit" and mapped_has_any(
            ["assets", "cash", "inventory", "receivable"]
        ):
            return False

        if tag_has("treasury stock"):
            # Treasury stock should only map to equity contexts
            if not mapped_has_any(["shareholders' equity", "equity statement"]):
                return False
            # Treasury stock must not be classified as assets or cash
            if mapped_has_any(["asset", "cash"]):
                return False
            # Restrictions on treasury stock should not be mapped to asset classes
            if tag_has("restriction") and mapped_has_any(["cash", "asset", "receivable"]):
                return False

        # Convertible notes must retain convertible classification
        if tag_has("convertible") and "convertible" not in mapped:
            return False

        if hits_blocked_pair([
            # Commercial paper should not be classified as shareholders' equity
            ("commercial paper", "shareholders' equity"),
            # Allowances should not map to liabilities
            ("allowance", "liabilities"),
            # Collateral should not map to goodwill (unrelated asset types)
            ("collateral", "goodwill"),
            # Derivative contracts should not map to minority interest
            ("derivatives and other contracts", "minority interest"),
            # Customer contracts should not map to property
            ("contract with customer", "property"),
            # Retail deposits are not customer advances
            ("deposits retail", "customer advances"),
            # Discontinued operation assets are not generic 'other current assets'
            ("discontinued operation", "other current assets"),
            # Assets held for sale must not be grouped under 'other current assets'
            ("assets held for sale", "other current assets"),
        ]):
            return False

    # === Cash Flow Statement Validation ===
    if stmt == "cash flow statement":
        investing_keywords = [
            "acquire", "investment", "purchase", "building", "property",
            "fund", "financing", "debt", "deposit", "debt reduction",
            "proceeds", "advance", "contribution"
        ]
        financing_keywords = ["distribution", "dividend", "redeemable securities"]

        # Block investing or financing terms mapping to operating activities
        if "operating activities" in mapped and any(
            tag_has(keyword) for keyword in investing_keywords + financing_keywords
        ):
            return False

        # Advances and contributions should not map to interest income
        if (tag_has("advance") or tag_has("contribution")) and "interest" in mapped:
            return False

        if hits_blocked_pair([
            # Subordinated debt repayment should not map to debt issuance
            ("repayments of subordinated debt", "debt issued"),
            # Date-related metadata should not map to cash flows
            ("original debt issuance date", "debt issued"),
            # Deposits should not be mapped to interest income
            ("deposit", "interest"),
            # Unrealized gain/loss should not map to investment purchases
            ("gainloss", "purchase of investments"),
            # Depreciation should not relate to discontinued ops or regulatory assets
            ("discontinued operations", "depreciation"),
            ("regulatory assets and liabilities", "depreciation"),
            # Dividend-related receivables should not map to interest
            ("dividends receivable", "interest"),
            # Withdrawals from contract holders should not be interest
            ("withdrawal from contract holders", "interest"),
            # Proceeds from secured credit lines are not short-term debt issuance
            ("secured lines of credit", "short term debt issued"),
            # Postemployment benefit payments are not tax payments
            ("postemployment benefits", "taxes paid"),
            # Specific: annuity/investment repayments are not debt issuance
            ("repayments of annuities and investment certificates", "debt issued"),
            # General: any repayment tag should not map to debt issuance
            ("repayment", "debt issued"),
            # Lease acquisition costs are not business acquisitions
            ("lease acquisition", "acquisition of business"),
            # Specific short-term debt repayments are not net short-term debt changes
            ("repayment", "short term debt, net"),
            # Deferred purchase price is not short-term debt issuance
            ("deferred purchase price", "short term debt issued"),
        ]):
            return False

    # === Comprehensive Income Validation ===
    # Must map to a comprehensive income context
    if stmt == "comprehensive income":
        if "comprehensive income" not in mapped:
            return False

    # === Equity Statement Validation ===
    if stmt == "equity statement":
        # Must map to an equity-related section
        if not mapped_has_any(["equity statement", "shareholders' equity"]):
            return False

        # Total equity incl. NCI is not redeemable noncontrolling interest
        if (tag_has("stockholders equity including portion attributable to noncontrolling")
                and "redeemable noncontrolling interest" in mapped):
            return False

    # === Income Statement Validation ===
    if stmt == "income statement":
        if tag_has("lease income"):
            # Lease income should relate to interest or lease concepts
            if "interest" not in mapped and "lease" not in mapped:
                return False
            # Variable lease income is not net income
            if "net income" in mapped:
                return False

        # Consumer loan interest income is not investment income
        if (tag_has("interest and fee income") and tag_has("loan")
                and "investment income" in mapped):
            return False

        if hits_blocked_pair([
            # Income from continuing ops should not be tied to minority interest
            ("income loss from continuing operations before income taxes", "minority interest"),
            # Interest/dividend income should not map to preferred dividends
            ("interest and dividend income", "preferred dividends"),
            # Bank-owned life insurance is not net sales
            ("bank owned life insurance", "net sales"),
            # Interest expense should not map to any kind of interest income
            ("interest expense", "interest income"),
            # Nonrecurring income should not be classified as investment income
            ("nonrecurring income", "investment income"),
            # Asset-related income is not investment income
            ("asset related income", "investment income"),
            # Noncontrolling interest should not be mapped to interest expense
            ("noncontrolling interest", "interest expense"),
            # Production taxes should not be classified under ambiguous 'other' expenses
            ("production tax", "other operating expenses"),
        ]):
            return False

    return True

# Load the nested OFSS schema, flatten it to "a/b/c" names, and embed the
# lowercased names once so every query can be compared against them.
with open(OFSS_JSON, "r") as ofss_file:
    nested_ofss = json.load(ofss_file)

flat_ofss = flatten_ofss(nested_ofss)
ofss_names = list(flat_ofss)
ofss_ids = list(flat_ofss.values())
ofss_embeddings = model.encode(
    [flattened_name.lower() for flattened_name in ofss_names],
    convert_to_tensor=True,
)

# === STATEMENT NORMALIZATION ===
# STATEMENT_NORMALIZATION = {
#     "Equity Statement": "Balance Sheet",
#     "Comprehensive Income": "Income Statement"
# }

def clean_path(raw):
    """Simplify a ``" > "``-joined hierarchy path for use in a text query.

    Elements whose name ends in "abstract" (any case) are dropped; in the
    remaining ones every ``us-gaap_`` is removed and underscores become spaces.
    """
    kept = []
    for token in raw.split(" > "):
        if token.lower().endswith("abstract"):
            continue
        kept.append(token.replace("us-gaap_", "").replace("_", " "))
    return " > ".join(kept)

# === MATCH FUNCTION ===
def try_match(query_str, stmt, balance, tag):
    """Find the best valid OFSS entry for a query within one statement.

    The lowercased query is embedded and compared (cosine similarity) with the
    OFSS names whose first "/" segment normalises to the same statement prefix
    as ``stmt``. Candidates are tried from most to least similar; the first
    one at or above SIMILARITY_THRESHOLD that passes ``is_valid_mapping`` wins.

    Returns:
        ``(index into ofss_names, similarity as float, display name)`` where
        the display name is lowercased with "/" shown as " > ", or
        ``(None, None, None)`` when nothing qualifies.
    """
    no_match = (None, None, None)
    query_emb = model.encode(query_str.lower(), convert_to_tensor=True)

    # Restrict the search to OFSS names belonging to the same statement.
    wanted_prefix = normalize_stmt_prefix(stmt)
    allowed_indices = []
    for position, ofss_name in enumerate(ofss_names):
        top_level = ofss_name.lower().split("/")[0]
        if normalize_stmt_prefix(top_level) == wanted_prefix:
            allowed_indices.append(position)
    if not allowed_indices:
        return no_match

    cos_scores = util.cos_sim(query_emb, ofss_embeddings[allowed_indices])[0]

    # Best candidates first.
    ranked_candidates = list(zip(allowed_indices, cos_scores))
    ranked_candidates.sort(key=lambda pair: pair[1], reverse=True)

    for candidate_idx, candidate_score in ranked_candidates:
        if candidate_score < SIMILARITY_THRESHOLD:
            # Sorted descending: nothing further can reach the threshold.
            return no_match

        mapped_name = ofss_names[candidate_idx].lower().replace("/", " > ")
        if is_valid_mapping(stmt, balance, tag, mapped_name):
            return candidate_idx, candidate_score.item(), mapped_name

    # Every candidate above the threshold failed validation.
    return no_match


# === PROCESS CSV ===
# Columns this step appends to every input row. All four are initialised on
# every row: csv.DictWriter raises ValueError for a row holding a key missing
# from `fieldnames`, which used to happen whenever the first row was unmatched
# (no "score" key) and a later row was matched.
OFSS_COLUMNS = ["ofss_id", "ofss_flattened_name", "map_approach", "score"]

output_rows = []
with open(INPUT_CSV, newline="") as f:
    reader = csv.DictReader(f)
    # Header of the input file; an empty list when the file has no header line.
    input_columns = list(reader.fieldnames or [])
    for row in tqdm(reader, desc="Processing rows"):
        stmt = row["statement_type"].strip()
        # stmt = STATEMENT_NORMALIZATION.get(stmt, stmt)

        tag = row["tag"]
        balance = row["balance"]
        desc = row.get("description", "").strip()
        path_raw = row["taxonomy_hierarchy"].strip()

        for column in OFSS_COLUMNS:
            row[column] = ""

        if not stmt or not path_raw:
            print(f"{tag} | [SKIPPED: missing statement or path]")
            output_rows.append(row)
            continue

        has_match = False

        # --- Approach 1: Description only ---
        if desc:
            query1 = f"{stmt} > {desc}"
            idx, score, name = try_match(query1, stmt, balance, tag)
            if idx is not None:
                has_match = True

                row["ofss_id"] = str(ofss_ids[idx])
                row["ofss_flattened_name"] = name.lower().replace("/", " > ")
                row["map_approach"] = "1"
                row["score"] = score
                print(f"{tag} | [mapped.1 ({score:.2f})] {query1} → {row['ofss_flattened_name']}")
            else:
                print(f"{tag} | [unmapped] {query1}")
        else:
            print(f"{tag} | [unmapped: no description]")

        # --- Approach 2: Taxonomy path (only when approach 1 found nothing) ---
        if not has_match:
            path = clean_path(path_raw)
            query2 = f"{stmt} > {path}"
            idx, score, name = try_match(query2, stmt, balance, tag)
            if idx is not None:
                row["ofss_id"] = str(ofss_ids[idx])
                row["ofss_flattened_name"] = name.lower().replace("/", " > ")
                row["map_approach"] = "2"
                row["score"] = score
                print(f"{tag} | [mapped.2 ({score:.2f})] {query2} → {row['ofss_flattened_name']}")

        output_rows.append(row)

# === SAVE TO CSV ===
# Build the header from the input columns plus the appended ones instead of
# from output_rows[0], so an input without data rows still yields a valid file.
fieldnames = input_columns + [c for c in OFSS_COLUMNS if c not in input_columns]
with open(OUTPUT_CSV, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(output_rows)


In [None]:
# import pandas as pd

# # Load existing CSV
# df = pd.read_csv(OUTPUT_CSV)




# # Identify invalid mappings
# invalid_mask = ~df.apply(is_valid_mapping, axis=1)
# invalid_entries = df[invalid_mask]

# # Display number of invalid entries
# print(f"Found {len(invalid_entries)} invalid entries.")

# # Clear mappings for invalid entries to allow manual review or reprocessing
# df.loc[invalid_mask, ["ofss_id", "ofss_flattened_name", "map_approach", "score"]] = ""

# # Save the corrected CSV
# df.to_csv("data/with_ofss_ids_corrected.csv", index=False)

# # Show sample of corrections made for verification
# invalid_entries.sample(min(10, len(invalid_entries)), random_state=42)[['tag', 'statement_type', 'description', 'ofss_flattened_name', 'score']]