In [None]:
import os
import re
import csv
import xml.etree.ElementTree as ET
from collections import defaultdict

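# FASB 2025 GAAP Financial Reporting Taxonomy project page:
# https://www.fasb.org/page/detail?pageId=/projects/FASB-Taxonomies/2025-gaap-financial-reporting-taxonomy.html
# 2025 taxonomy download (this cell reads the extracted files from TAXONOMY_DIR): https://xbrl.fasb.org/us-gaap/2025/us-gaap-2025.zip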

# === CONFIG ===
# Root of the extracted taxonomy and the two locations this cell reads from it.
TAXONOMY_DIR = "data/us-gaap-2025/"
STM_DIR = os.path.join(TAXONOMY_DIR, "stm")
ELTS_XSD = os.path.join(TAXONOMY_DIR, "elts", "us-gaap-2025.xsd")

# Destination of the CSV written at the end of this cell.
OUTPUT_PATH = "data/WITH_TAXONOMY_HIERARCHY_us_gaap_2025_with_all_statements_and_hierarchy.csv"

# Statement code taken from a presentation-linkbase filename -> statement name
# used in the "statement_type" column.
FILENAME_STATEMENT_MAP = dict(
    scf="Cash Flow Statement",
    soi="Income Statement",
    sfp="Balance Sheet",
    sheci="Equity Statement",
    soc="Comprehensive Income",
)

# ElementTree exposes namespaced attributes as "{namespace-uri}local-name".
XBRLI_NS = "http://www.xbrl.org/2003/instance"
BALANCE_KEY = "{" + XBRLI_NS + "}balance"
PERIOD_TYPE_KEY = "{" + XBRLI_NS + "}periodType"


def generate_description(tag_name):
    """Turn a CamelCase tag name into a lower-case, space-separated phrase.

    A space is inserted at every lower->upper boundary and before the last
    capital of an upper-case run that is followed by a lower-case letter
    (so "XMLParser" becomes "xml parser"). Digits never trigger a split.

    Args:
        tag_name: Tag name string, e.g. "NetIncomeLoss".

    Returns:
        The spaced, lower-cased description, e.g. "net income loss".
    """
    word_boundary = re.compile(r'(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])')
    spaced = word_boundary.sub(' ', tag_name)
    return spaced.lower()


# === STEP 1: TAG METADATA ===
# Scan every xs:element declaration in the taxonomy schema and keep the named,
# non-abstract ones. Each kept tag starts with its balance / periodType
# attributes (empty string when absent) and an empty set of statements that a
# later step fills in.
tag_metadata = {}
tree = ET.parse(ELTS_XSD)
root = tree.getroot()
for el in root.findall(".//{http://www.w3.org/2001/XMLSchema}element"):
    attrs = el.attrib
    name = attrs.get("name")
    is_concrete = bool(name) and attrs.get("abstract") != "true"
    if is_concrete:
        tag_metadata[name] = dict(
            balance=attrs.get(BALANCE_KEY, ""),
            period_type=attrs.get(PERIOD_TYPE_KEY, ""),
            statements=set(),
        )

# === STEP 2: STATEMENT TYPE MAPPING ===
# For each presentation linkbase ("*-pre-*.xml") in STM_DIR, the fourth
# dash-separated filename field is looked up in FILENAME_STATEMENT_MAP. Every
# known tag that the file locates is then recorded as belonging to that statement.
for file in os.listdir(STM_DIR):
    if not (file.endswith(".xml") and "-pre-" in file):
        continue
    parts = file.split("-")
    if len(parts) < 4:
        continue
    stmt_key = parts[3]
    inferred_statement = FILENAME_STATEMENT_MAP.get(stmt_key)
    if not inferred_statement:
        # Filename does not name one of the statements we track.
        continue

    tree = ET.parse(os.path.join(STM_DIR, file))
    root = tree.getroot()
    for loc in root.findall(".//{http://www.xbrl.org/2003/linkbase}loc"):
        href = loc.attrib.get("{http://www.w3.org/1999/xlink}href", "")
        # The tag id is whatever follows the last '#' of the locator href.
        tag = href.rsplit("#", 1)[-1]
        if tag.startswith("us-gaap_"):
            tag = tag.replace("us-gaap_", "")
        entry = tag_metadata.get(tag)
        if entry is not None:
            entry["statements"].add(inferred_statement)

# === STEP 3: BUILD BASE OUTPUT ===
# Emit one row per (tag, statement) pair; tags that were never seen in a
# statement linkbase produce no rows. The statements are kept in a set, whose
# iteration order for strings changes between interpreter runs, so they are
# sorted here to make the row order (and therefore the CSV) reproducible.
final_output = []
for tag, meta in tag_metadata.items():
    for stmt in sorted(meta["statements"]):
        final_output.append({
            "tag": tag,
            "statement_type": stmt,
            "balance": meta["balance"],
            "period_type": meta["period_type"],
            "description": generate_description(tag),
            "subcategory_path": ""  # placeholder column, always written empty here
        })

# === STEP 4: EXTRACT TAXONOMY HIERARCHY PATHS ===
def extract_tag_hierarchy_paths(pre_files):
    """Collect the presentation-hierarchy paths of every us-gaap tag.

    Each file is parsed as a presentation linkbase. Locators whose href
    contains "us-gaap_" are indexed by their xlink:label, presentation arcs
    are grouped by their source label, and the resulting graph is walked from
    every label that is a source but never a target.

    Args:
        pre_files: Iterable of paths to presentation linkbase XML files.

    Returns:
        defaultdict(set) mapping a tag name (without the "us-gaap_" prefix)
        to the set of " > "-joined paths leading to it. Ancestors in a path
        keep the raw locator fragment (including the "us-gaap_" prefix); only
        the final element is stripped.

    Notes:
        Files that are not well-formed XML are skipped with a warning instead
        of aborting the run. Arcs that point back to a label already on the
        current path are ignored, so a cyclic linkbase cannot cause unbounded
        recursion.
    """
    import warnings  # function-scope import keeps the file's import cell untouched

    link_ns = "{http://www.xbrl.org/2003/linkbase}"
    xlink_ns = "{http://www.w3.org/1999/xlink}"

    tag_paths = defaultdict(set)

    for pre_file in pre_files:
        try:
            root = ET.parse(pre_file).getroot()
        except ET.ParseError as exc:
            warnings.warn(f"Skipping unparseable presentation linkbase {pre_file}: {exc}")
            continue

        # xlink:label -> href fragment (e.g. "us-gaap_Assets") for us-gaap locators.
        loc_map = {}
        for loc in root.findall(f".//{link_ns}loc"):
            loc_id = loc.attrib.get(f"{xlink_ns}label")
            href = loc.attrib.get(f"{xlink_ns}href", "")
            if loc_id and "us-gaap_" in href:
                loc_map[loc_id] = href.split("#")[-1]

        # Parent label -> list of child labels, in document order.
        arcs_by_from = defaultdict(list)
        for arc in root.findall(f".//{link_ns}presentationArc"):
            from_id = arc.attrib.get(f"{xlink_ns}from")
            to_id = arc.attrib.get(f"{xlink_ns}to")
            arcs_by_from[from_id].append(to_id)

        all_to = {to for children in arcs_by_from.values() for to in children}
        root_ids = set(arcs_by_from) - all_to

        def walk(current_id, ancestors, active_ids):
            """Record paths for the children of current_id, depth first."""
            tag = loc_map.get(current_id)
            current_path = ancestors + [tag] if tag else ancestors
            active_ids.add(current_id)
            # .get() avoids inserting empty entries into the defaultdict.
            for child_id in arcs_by_from.get(current_id, []):
                if child_id in active_ids:
                    # Arc loops back to an ancestor (or to itself): skip it.
                    continue
                child_tag = loc_map.get(child_id)
                if child_tag and child_tag.startswith("us-gaap_"):
                    clean_tag = child_tag.replace("us-gaap_", "")
                    tag_paths[clean_tag].add(" > ".join(current_path + [clean_tag]))
                walk(child_id, current_path, active_ids)
            active_ids.discard(current_id)

        for root_id in root_ids:
            walk(root_id, [], set())

    return tag_paths


# Gather all presentation files (stm + dis + others)
# Any "*.xml" file whose name contains "-pre-" anywhere under TAXONOMY_DIR counts.
pre_files = []
for root, _, files in os.walk(TAXONOMY_DIR):
    pre_files.extend(
        os.path.join(root, file)
        for file in files
        if file.endswith(".xml") and "-pre-" in file
    )

tag_hierarchy = extract_tag_hierarchy_paths(pre_files)

# === STEP 5: ATTACH HIERARCHY TO TAGS ===
# Statement names a row must carry to be written out.
ALL_STATEMENTS = {
    "Cash Flow Statement",
    "Income Statement",
    "Balance Sheet",
    "Equity Statement",
    "Comprehensive Income"
}

final_filtered = []
for row in final_output:
    if row["statement_type"] not in ALL_STATEMENTS:
        continue
    tag = row["tag"]
    # .get() leaves the defaultdict untouched for tags without any path.
    paths = tag_hierarchy.get(tag)
    # A tag may have several hierarchy paths. Taking "the first element" of a
    # set is arbitrary and differs between interpreter runs, so the
    # lexicographically smallest path is chosen to keep the CSV reproducible.
    row["taxonomy_hierarchy"] = min(paths) if paths else ""
    final_filtered.append(row)

# === STEP 6: WRITE OUTPUT CSV ===
# The directory part of OUTPUT_PATH is not created here; it has to exist already.
csv_columns = [
    "tag",
    "statement_type",
    "balance",
    "period_type",
    "description",
    "subcategory_path",
    "taxonomy_hierarchy",
]
with open(OUTPUT_PATH, "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=csv_columns)
    writer.writeheader()
    for record in final_filtered:
        writer.writerow(record)


# Append OFSS IDs

In [None]:
import csv
import json
import torch
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
from rapidfuzz.fuzz import token_sort_ratio

# === CONFIG ===
MODEL_NAME = "BAAI/bge-large-en-v1.5"  # model name handed to SentenceTransformer
INPUT_CSV = "data/WITH_TAXONOMY_HIERARCHY_us_gaap_2025_with_all_statements_and_hierarchy.csv"
OFSS_JSON = "../shared/open_financial_statement_schema.json"
OUTPUT_CSV = "data/with_ofss_ids.csv"
SIMILARITY_THRESHOLD = 0.7  # a match must score strictly above this to be accepted

# === DEVICE SETUP ===
# Preference order: Apple MPS, then CUDA, then CPU.
device = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

model = SentenceTransformer(MODEL_NAME, device=device)

# === FLATTEN NESTED OFSS MAP ===
def flatten_ofss(d, parent_key=""):
    """Flatten a nested mapping into a single-level dict keyed by " > " paths.

    Args:
        d: Mapping whose values are either nested dicts or leaf values.
        parent_key: Path prefix prepended to every key (empty for the top level).

    Returns:
        Dict from "outer > inner > leaf" path to the leaf value, in depth-first
        key order.
    """
    def _leaves(node, prefix):
        # Yield (path, value) for every non-dict value below `node`.
        for key, value in node.items():
            path = f"{prefix} > {key}" if prefix else key
            if isinstance(value, dict):
                yield from _leaves(value, path)
            else:
                yield path, value

    return dict(_leaves(d, parent_key))

def normalize_stmt_prefix(stmt):
    """Lower-case a statement name and drop every " statement" occurrence.

    Example: "Cash Flow Statement" -> "cash flow".
    """
    lowered = stmt.lower()
    return "".join(lowered.split(" statement"))

def clean_path(raw):
    """Simplify a " > "-joined hierarchy path for use in a text query.

    Segments ending in "abstract" (case-insensitive) are dropped; in the
    remaining segments "us-gaap_" is removed and underscores become spaces.

    Args:
        raw: Path string such as "us-gaap_AssetsAbstract > Assets".

    Returns:
        The cleaned path, re-joined with " > ".
    """
    return " > ".join(
        segment.replace("us-gaap_", "").replace("_", " ")
        for segment in raw.split(" > ")
        if not segment.lower().endswith("abstract")
    )

# === LOAD OFSS ===
# JSON text is UTF-8 by specification, so decode it explicitly rather than
# relying on the platform's default encoding.
with open(OFSS_JSON, "r", encoding="utf-8") as f:
    nested_ofss = json.load(f)

# Flattened "a > b > c" names and their ids are kept as parallel lists; the
# embeddings below are row-aligned with ofss_names (names are lower-cased
# before encoding).
flat_ofss = flatten_ofss(nested_ofss)
ofss_names = list(flat_ofss.keys())
ofss_ids = [flat_ofss[name] for name in ofss_names]
ofss_embeddings = model.encode([s.lower() for s in ofss_names], convert_to_tensor=True)

# === MATCH FUNCTION ===
def try_match_multi(stmt, desc, tag, path_raw):
    """Find the best OFSS entry for a tag by trying several query phrasings.

    Up to four queries are built ("stmt > desc", "stmt > cleaned path",
    "stmt > tag", "stmt > desc > tag"). Each is embedded and compared with
    every OFSS name whose first path segment normalizes to the same statement
    as ``stmt``. Candidates are ranked by
    ``0.85 * cosine similarity + 0.15 * token_sort_ratio / 100``.

    Args:
        stmt: Statement type of the row, e.g. "Balance Sheet".
        desc: Human-readable description of the tag (may be empty).
        tag: Tag name (may be empty).
        path_raw: Raw " > "-joined hierarchy path (may be empty).

    Returns:
        ``(index, cosine_score, ofss_name, query)`` for the best candidate
        whose combined score is strictly above SIMILARITY_THRESHOLD, where
        ``index`` points into ``ofss_names``. Note that the returned score is
        the raw cosine similarity, not the combined score used for ranking.
        ``(None, None, None, None)`` when nothing qualifies.
    """
    no_match = (None, None, None, None)

    queries = []
    if desc:
        queries.append(f"{stmt} > {desc}")
    if path_raw:
        queries.append(f"{stmt} > {clean_path(path_raw)}")
    if tag:
        queries.append(f"{stmt} > {tag}")
    if desc and tag:
        queries.append(f"{stmt} > {desc} > {tag}")
    if not queries:
        return no_match

    # The statement filter does not depend on the query, so resolve it once
    # instead of re-normalizing both sides for every (query, candidate) pair.
    stmt_prefix = normalize_stmt_prefix(stmt)
    candidates = [
        (i, name, name.lower())
        for i, name in enumerate(ofss_names)
        if normalize_stmt_prefix(name.split(" > ")[0]) == stmt_prefix
    ]
    if not candidates:
        # Nothing can match; skip the embedding work entirely.
        return no_match

    best = None
    best_score = -1

    for query_str in queries:
        query_lower = query_str.lower()
        query_emb = model.encode(query_lower, convert_to_tensor=True)
        # One transfer to host for the whole row instead of an .item() call
        # (and device sync) per element.
        cos_scores = util.cos_sim(query_emb, ofss_embeddings)[0].tolist()

        for i, name, name_lower in candidates:
            score = cos_scores[i]
            fuzzy_score = token_sort_ratio(query_lower, name_lower) / 100
            combined_score = (0.85 * score) + (0.15 * fuzzy_score)

            if combined_score > SIMILARITY_THRESHOLD and combined_score > best_score:
                best = (i, score, name, query_str)
                best_score = combined_score

    if best is None:
        return no_match
    return best

# === PROCESS CSV ===
# Columns this step appends to each input row. Every row gets all of them
# (empty when unmapped) so that all rows share the same keys; previously
# "score" existed only on mapped rows, and the header was taken from the first
# row, which made DictWriter raise ValueError when the first row was unmapped
# and a later one was mapped.
added_columns = ["ofss_id", "ofss_flattened_name", "map_approach", "score"]

output_rows = []
with open(INPUT_CSV, newline="") as f:
    reader = csv.DictReader(f)
    # Header is read here so it is available even when the file has no data rows.
    input_columns = list(reader.fieldnames or [])
    for row in tqdm(reader, desc="Processing rows"):
        stmt = row["statement_type"].strip()
        tag = row["tag"]
        desc = row.get("description", "").strip()
        path_raw = row["taxonomy_hierarchy"].strip()

        for column in added_columns:
            row[column] = ""

        if not stmt or not path_raw:
            print(f"{tag} | [SKIPPED: missing statement or path]")
            output_rows.append(row)
            continue

        idx, score, name, query_used = try_match_multi(stmt, desc, tag, path_raw)
        if idx is not None:
            row["ofss_id"] = str(ofss_ids[idx])
            row["ofss_flattened_name"] = name.lower()
            row["map_approach"] = "multi"
            row["score"] = score
            print(f"{tag} | [mapped.multi ({score:.2f})] {query_used} → {row['ofss_flattened_name']}")
        else:
            print(f"{tag} | [unmapped.multi]")

        output_rows.append(row)

# === SAVE TO CSV ===
# Header = input columns followed by the appended ones; it no longer depends on
# the first output row, so an empty input yields a header-only file instead of
# an IndexError.
fieldnames = input_columns + [c for c in added_columns if c not in input_columns]
with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(output_rows)
