#### Stage-1 importer (ONLY drug table)
drug ✅ (root table; everything else hangs from this)

In [23]:
# =========================
# Stage 1: Load drug table (FULL, reliable)
# =========================
import time
from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# --------- EDIT THESE ----------
XML_PATH = r"db/drugbank.xml"

MYSQL_HOST="localhost"
MYSQL_PORT=3306
MYSQL_USER="root"
MYSQL_PASSWORD=""
MYSQL_DB="drugbank"

COMMIT_EVERY = 1000
SHOW_EVERY = 10000

# Set True for clean reload (deletes all rows from drug first)
WIPE_DRUG_TABLE = False
# -------------------------------

def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
        password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=False
    )

def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def detect_namespace(xml_path: str):
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem):
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def to_float(s):
    if s is None:
        return None
    try:
        return float(s)
    except:
        return None

def bool_from_text(s):
    if s is None:
        return None
    v = str(s).strip().lower()
    if v in ("true", "1", "yes"): return True
    if v in ("false", "0", "no"): return False
    return None

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def count_total_drugs(xml_path: str) -> int:
    total = 0
    p = tqdm(total=None, unit="drug", dynamic_ncols=True, desc="Pre-pass: counting <drug>")
    try:
        ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, e in ctx:
            if strip_ns(e.tag) == "drug":
                total += 1
                p.update(1)
                cleanup_lxml(e)
    finally:
        try: p.close()
        except: pass
    return total

def ensure_unique_index_on_drug(conn):
    """
    Ensures drug.primary_drugbank_id has a UNIQUE index.
    This is REQUIRED for ON DUPLICATE KEY UPDATE to work correctly.
    """
    cur = conn.cursor()
    cur.execute("SHOW INDEX FROM drug WHERE Column_name='primary_drugbank_id'")
    rows = cur.fetchall()
    has_unique = any(int(r[1]) == 0 for r in rows)  # Non_unique == 0
    if not has_unique:
        print("[INFO] Adding UNIQUE index on drug(primary_drugbank_id)...")
        cur.execute("ALTER TABLE drug ADD UNIQUE KEY uq_drug_primary_id (primary_drugbank_id)")
        conn.commit()
    cur.close()

DRUG_UPSERT_SQL = """
INSERT INTO drug (
  primary_drugbank_id, name, description, cas_number, unii,
  average_mass, monoisotopic_mass, state,
  synthesis_reference, indication, pharmacodynamics, mechanism_of_action,
  toxicity, metabolism, absorption, half_life, protein_binding,
  route_of_elimination, volume_of_distribution, clearance,
  fda_label, msds
)
VALUES (%s,%s,%s,%s,%s,
        %s,%s,%s,
        %s,%s,%s,%s,
        %s,%s,%s,%s,%s,
        %s,%s,%s,
        %s,%s)
ON DUPLICATE KEY UPDATE
  name = VALUES(name),
  description = VALUES(description),
  cas_number = VALUES(cas_number),
  unii = VALUES(unii),
  average_mass = VALUES(average_mass),
  monoisotopic_mass = VALUES(monoisotopic_mass),
  state = VALUES(state),
  synthesis_reference = VALUES(synthesis_reference),
  indication = VALUES(indication),
  pharmacodynamics = VALUES(pharmacodynamics),
  mechanism_of_action = VALUES(mechanism_of_action),
  toxicity = VALUES(toxicity),
  metabolism = VALUES(metabolism),
  absorption = VALUES(absorption),
  half_life = VALUES(half_life),
  protein_binding = VALUES(protein_binding),
  route_of_elimination = VALUES(route_of_elimination),
  volume_of_distribution = VALUES(volume_of_distribution),
  clearance = VALUES(clearance),
  fda_label = VALUES(fda_label),
  msds = VALUES(msds)
"""

def import_stage1_drug(xml_path: str):
    ns = detect_namespace(xml_path)
    nstag = nstag_factory(ns)
    print("Detected namespace:", ns)

    total = count_total_drugs(xml_path)
    print(f"Total <drug> in XML: {total:,}")

    conn = get_conn()
    cur = conn.cursor()

    # Ensure uniqueness so upsert behaves correctly
    ensure_unique_index_on_drug(conn)

    if WIPE_DRUG_TABLE:
        print("[CAUTION] Deleting all rows from drug...")
        cur.execute("DELETE FROM drug")
        conn.commit()

    processed = 0
    errors = 0

    inserted = 0
    updated = 0
    unchanged = 0

    t0 = time.time()
    pbar = tqdm(total=total, unit="drug", dynamic_ncols=True, desc="Stage 1: drug")

    try:
        ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, elem in ctx:
            if strip_ns(elem.tag) != "drug":
                continue

            try:
                # ---- choose primary_drugbank_id deterministically ----
                ids = elem.findall(nstag("drugbank-id"))
                primary_id = None
                first_id = None
                for ide in ids:
                    did = text(ide)
                    if not did:
                        continue
                    if first_id is None:
                        first_id = did
                    if bool_from_text(ide.get("primary")):
                        primary_id = did
                        break
                primary_id = primary_id or first_id
                if not primary_id:
                    errors += 1
                    processed += 1
                    pbar.update(1)
                    cleanup_lxml(elem)
                    continue

                def f(tag):
                    return text(elem.find(nstag(tag)))

                row = (
                    primary_id,
                    f("name") or "",     # NOT NULL
                    f("description"),
                    f("cas-number"),
                    f("unii"),
                    to_float(f("average-mass")),
                    to_float(f("monoisotopic-mass")),
                    f("state"),
                    f("synthesis-reference"),
                    f("indication"),
                    f("pharmacodynamics"),
                    f("mechanism-of-action"),
                    f("toxicity"),
                    f("metabolism"),
                    f("absorption"),
                    f("half-life"),
                    f("protein-binding"),
                    f("route-of-elimination"),
                    f("volume-of-distribution"),
                    f("clearance"),
                    f("fda-label"),
                    f("msds"),
                )

                cur.execute(DRUG_UPSERT_SQL, row)

                # classify what happened
                cur.execute("SELECT ROW_COUNT()")
                rc = int(cur.fetchone()[0])
                if rc == 1:
                    inserted += 1
                elif rc == 2:
                    updated += 1
                else:
                    unchanged += 1

                processed += 1
                pbar.update(1)

                if processed % COMMIT_EVERY == 0:
                    conn.commit()

                if processed % SHOW_EVERY == 0:
                    dt = time.time() - t0
                    rate = processed / dt if dt > 0 else 0.0
                    print(f"[Stage1] {processed:,}/{total:,} rate={rate:.2f}/sec "
                          f"ins={inserted:,} upd={updated:,} same={unchanged:,} err={errors:,}")

            except Exception as e:
                errors += 1
                conn.rollback()
                print("ERROR on drug:", e)

            finally:
                cleanup_lxml(elem)

        conn.commit()

    finally:
        try: pbar.close()
        except: pass
        cur.close()
        conn.close()

    dt = time.time() - t0
    rate = processed / dt if dt > 0 else 0.0
    print("\nDONE Stage 1 (drug)")
    print(f"Processed:  {processed:,} (expected {total:,})")
    print(f"Inserted:   {inserted:,}")
    print(f"Updated:    {updated:,}")
    print(f"Unchanged:  {unchanged:,}")
    print(f"Errors:     {errors:,}")
    print(f"Elapsed:    {dt:.1f}s  Rate: {rate:.2f}/sec")

    # sanity check from DB
    conn2 = get_conn()
    cur2 = conn2.cursor()
    cur2.execute("SELECT COUNT(*), COUNT(DISTINCT primary_drugbank_id) FROM drug")
    total_rows, distinct_ids = cur2.fetchone()
    cur2.close(); conn2.close()
    print(f"DB check: drug rows={int(total_rows):,}, distinct primary_drugbank_id={int(distinct_ids):,}")

# RUN
import_stage1_drug(XML_PATH)

Detected namespace: http://www.drugbank.ca


Pre-pass: counting <drug>: 0drug [00:00, ?drug/s]

Total <drug> in XML: 73,687


Stage 1: drug:   0%|                                                                       | 0/73687 [00:00<?,…

[Stage1] 10,000/73,687 rate=897.33/sec ins=342 upd=62 same=9,596 err=0
[Stage1] 20,000/73,687 rate=705.49/sec ins=720 upd=207 same=19,073 err=0
[Stage1] 30,000/73,687 rate=536.23/sec ins=1,851 upd=617 same=27,532 err=0
[Stage1] 40,000/73,687 rate=614.82/sec ins=2,799 upd=745 same=36,456 err=0
[Stage1] 50,000/73,687 rate=674.62/sec ins=3,759 upd=872 same=45,369 err=0
[Stage1] 60,000/73,687 rate=706.78/sec ins=5,261 upd=966 same=53,773 err=0
[Stage1] 70,000/73,687 rate=543.66/sec ins=13,743 upd=1,039 same=55,218 err=0

DONE Stage 1 (drug)
Processed:  73,687 (expected 73,687)
Inserted:   17,430
Updated:    1,039
Unchanged:  55,218
Errors:     0
Elapsed:    133.7s  Rate: 551.12/sec
DB check: drug rows=17,430, distinct primary_drugbank_id=17,430


#### Stage 2 — Direct children of drug (FK → drug)
These only need drug_pk:
drugbank_id_map
drug_group
drug_classification (one row per drug)
drug_classification_alternative_parent
drug_classification_substituent
drug_synonym
drug_salt
product
packager
manufacturer
drug_category
drug_affected_organism
drug_dosage
drug_ahfs_code
drug_pdb_entry
drug_atc_code
drug_atc_level (depends on drug_atc_code via FK (drug_pk, atc_code))
drug_price
drug_patent
drug_food_interaction
drug_interaction
drug_external_identifier
drug_external_link
drug_property
reaction (depends on drug_pk)

Below is a Stage-2 only notebook-friendly solution that does exactly what you asked:
One pass per table (it re-reads / re-scans the XML for each table)
One function per Stage-2 table
Shows progress bar for each table pass
At the end, prints a final summary table:
how many rows were attempted / affected during that pass
the current row count in the DB for that table (SELECT COUNT(*))
✅ Assumption: Stage-1 (drug) is already loaded, so we can look up drug_pk from primary_drugbank_id.

## Cautious it will delete all the data from Stage - 2

In [22]:
import mysql.connector

MYSQL_HOST="localhost"
MYSQL_PORT=3306
MYSQL_USER="root"
MYSQL_PASSWORD=""
MYSQL_DB="drugbank"

def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
        password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=False
    )

def table_exists(cur, table: str) -> bool:
    cur.execute(
        "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema=%s AND table_name=%s",
        (MYSQL_DB, table),
    )
    return int(cur.fetchone()[0]) == 1

def wipe_table(cur, table: str) -> str:
    """
    Try TRUNCATE (fast). If FK blocks it, fall back to DELETE (safe).
    """
    try:
        cur.execute(f"TRUNCATE TABLE `{table}`")
        return "TRUNCATE"
    except mysql.connector.Error as e:
        cur.execute(f"DELETE FROM `{table}`")
        return f"DELETE (fallback: {e.errno})"

# ---------- FK-safe wipe order (children -> parents) ----------
WIPE_ORDER = [
    # Stage 9
    "drug_raw",

    # Stage 8 (children of polypeptide)
    "polypeptide_external_identifier",
    "polypeptide_synonym",

    # Stage 7 (children of interactant, and interactant itself)
    "interactant_polypeptide",
    "interactant_action",
    "interactant_reference",
    "interactant",

    # Stage 6 (children of reaction)
    "reaction_enzyme",
    "reaction_element",

    # Stage 5 (pathway joins and members)
    "pathway_drug_member",
    "pathway_enzyme_member",
    "drug_pathway",
    "pathway",

    # Stage 4
    "drug_reference",
    "reference_item",

    # Stage 3
    "drug_salt_drugbank_id",

    # Stage 2 (children of drug)
    "drug_atc_level",
    "drug_atc_code",
    "drug_property",
    "drug_external_link",
    "drug_external_identifier",
    "drug_interaction",
    "drug_food_interaction",
    "drug_patent",
    "drug_price",
    "reaction",  # parent of reaction_* tables (already deleted above)
    "drug_pdb_entry",
    "drug_ahfs_code",
    "drug_dosage",
    "drug_affected_organism",
    "drug_category",
    "manufacturer",
    "packager",
    "product",
    "drug_salt",  # parent of drug_salt_drugbank_id (already deleted above)
    "drug_synonym",
    "drug_classification_substituent",
    "drug_classification_alternative_parent",
    "drug_classification",
    "drug_group",
    "drugbank_id_map",

    # Polypeptide base table (after its children + links)
    "polypeptide",

    # Stage 1 (root)
    "drug",
]

print("⚠️ This will wipe ALL DrugBank tables listed in WIPE_ORDER.")
confirm = input('Type WIPE to continue: ').strip()
if confirm != "WIPE":
    print("Cancelled.")
else:
    conn = get_conn()
    cur = conn.cursor()
    results = []
    try:
        for t in WIPE_ORDER:
            if not table_exists(cur, t):
                results.append((t, "SKIP (missing)"))
                continue
            method = wipe_table(cur, t)
            results.append((t, method))
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        cur.close()
        conn.close()

    print("\nWipe results:")
    for t, m in results:
        print(f"{t:35} -> {m}")


⚠️ This will wipe ALL DrugBank tables listed in WIPE_ORDER.


Type WIPE to continue:  WIPE



Wipe results:
drug_raw                            -> TRUNCATE
polypeptide_external_identifier     -> TRUNCATE
polypeptide_synonym                 -> TRUNCATE
interactant_polypeptide             -> TRUNCATE
interactant_action                  -> TRUNCATE
interactant_reference               -> TRUNCATE
interactant                         -> DELETE (fallback: 1701)
reaction_enzyme                     -> TRUNCATE
reaction_element                    -> TRUNCATE
pathway_drug_member                 -> TRUNCATE
pathway_enzyme_member               -> TRUNCATE
drug_pathway                        -> TRUNCATE
pathway                             -> DELETE (fallback: 1701)
drug_reference                      -> TRUNCATE
reference_item                      -> DELETE (fallback: 1701)
drug_salt_drugbank_id               -> TRUNCATE
drug_atc_level                      -> TRUNCATE
drug_atc_code                       -> DELETE (fallback: 1701)
drug_property                       -> TRUNCATE
drug_external

<small>

### Stage 2 — Direct children of `drug` (FK → drug)

This stage loads the **tables that depend only on `drug_pk`** (foreign key to the `drug` table).

**Prerequisite:** Stage 1 must be completed (`drug` table populated), so every drug has a valid `drug_pk`.

---

#### Tables loaded in Stage 2

- `drugbank_id_map`
- `drug_group`
- `drug_classification`
- `drug_classification_alternative_parent`
- `drug_classification_substituent`
- `drug_synonym`
- `drug_salt`
- `product`
- `packager`
- `manufacturer`
- `drug_category`
- `drug_affected_organism`
- `drug_dosage`
- `drug_ahfs_code`
- `drug_pdb_entry`
- `drug_atc_code`
- `drug_atc_level` *(must run after `drug_atc_code`)*
- `drug_price`
- `drug_patent`
- `drug_food_interaction`
- `drug_interaction`
- `drug_external_identifier`
- `drug_external_link`
- `drug_property`
- `reaction`

---

#### Execution strategy (one XML scan per table)
This Stage 2 workflow runs **one full XML pass per table**, meaning:

- The XML file is re-read multiple times (once for each table).
- Each table gets its own:
  - progress bar
  - periodic status logs
  - commit batching

This is **slower** than a single-pass multi-table insert, but it is useful for debugging and isolating issues.

---

#### Dependency / sequencing
Inside Stage 2, the most important ordering rule is:

- **`drug_atc_code` must be loaded before `drug_atc_level`**

Everything else depends only on `drug_pk` and can be loaded independently.

---

#### Re-run safety and “skip if already loaded”
Some tables are naturally **idempotent** because the loader uses `INSERT IGNORE` or `ON DUPLICATE KEY UPDATE`
(e.g., many mapping tables).

Other tables can **duplicate on reruns** because they have no natural unique key or use auto-increment IDs
(e.g., `drug_salt`, `product`, `drug_price`, `drug_patent`, `drug_property`, `reaction`).

To avoid accidental duplicates:
- the Stage 2 runner can **skip a table** if it already has rows, OR
- explicitly **wipe (DELETE)** and reload selected tables.

---

#### Output / reporting
For each table pass, the runner reports:
- progress bar over total drugs in DB
- attempted vs affected rows
- per-table time spent

At the end, it prints a consolidated **summary table**:
- Table name
- Status (LOADED / SKIPPED)
- Attempted statements
- Affected rows (rowcount)
- Rows currently in DB
- Time taken

</small>


In [25]:
# =====================================================================================
# Stage 2 loader (ALL Stage-2 tables) - corrected + pre-pass count + robust ID handling
# One XML pass per table, progress bar based on total drugs in XML.
# =====================================================================================

import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, Callable, Tuple, List

from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# ---------------------------
# CONFIG (EDIT THESE)
# ---------------------------
XML_PATH = r"db/drugbank.xml"

MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "root"
MYSQL_PASSWORD = ""
MYSQL_DB = "drugbank"

COMMIT_EVERY = 1000
SHOW_EVERY = 10000

# skip table if already has rows (prevents duplicates for append tables)
SKIP_IF_NONEMPTY = True

# force reload (wipe + reload) only these tables
FORCE_RELOAD_TABLES = set()  # e.g. {"drug_property", "reaction"}

# wipe before load these tables (DELETE FROM)
WIPE_BEFORE_LOAD_TABLES = set()  # e.g. {"drug_property", "reaction"}
# ---------------------------

def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
        password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=False
    )

def count_table(conn, table: str) -> int:
    cur = conn.cursor()
    cur.execute(f"SELECT COUNT(*) FROM `{table}`")
    n = int(cur.fetchone()[0])
    cur.close()
    return n

def wipe_table_delete(table: str) -> None:
    conn = get_conn()
    cur = conn.cursor()
    cur.execute(f"DELETE FROM `{table}`")
    conn.commit()
    cur.close()
    conn.close()

# ---------------------------
# XML helpers
# ---------------------------
def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def detect_namespace(xml_path: str) -> Optional[str]:
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns: Optional[str]):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem) -> Optional[str]:
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def to_float(s: Optional[str]) -> Optional[float]:
    if s is None:
        return None
    try:
        return float(s)
    except:
        return None

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def count_total_drugs_with_progress(xml_path: str) -> int:
    total = 0
    p = tqdm(total=None, unit="drug", dynamic_ncols=True, desc="Pre-pass: counting <drug>")
    try:
        ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, e in ctx:
            if strip_ns(e.tag) == "drug":
                total += 1
                p.update(1)
                cleanup_lxml(e)
    finally:
        try: p.close()
        except: pass
    return total

def pick_primary_id(drug_elem, nstag) -> Optional[str]:
    """
    Robust primary id selection:
    - if @primary exists and not in {false,0,no,n,''} => choose it
    - else first non-empty drugbank-id
    """
    ids = drug_elem.findall(nstag("drugbank-id"))
    first = None
    for ide in ids:
        did = text(ide)
        if not did:
            continue
        if first is None:
            first = did
        prim = ide.get("primary")
        if prim is not None:
            v = prim.strip().lower()
            if v not in ("false", "0", "no", "n", ""):
                return did
    return first

def load_drug_pk_map(conn) -> Dict[str, int]:
    cur = conn.cursor()
    cur.execute("SELECT primary_drugbank_id, drug_pk FROM drug")
    mp = {row[0]: int(row[1]) for row in cur.fetchall()}
    cur.close()
    return mp

# ---------------------------
# Reporting
# ---------------------------
def print_table(rows: List[List[Any]], headers: List[str]) -> None:
    cols = len(headers)
    widths = [len(h) for h in headers]
    for r in rows:
        for i in range(cols):
            widths[i] = max(widths[i], len(str(r[i])))

    def sep(ch="-"):
        return "+".join(ch * (w + 2) for w in widths)

    def render(r):
        return "|".join(" " + str(r[i]).ljust(widths[i]) + " " for i in range(cols))

    print(sep("-"))
    print(render(headers))
    print(sep("="))
    for r in rows:
        print(render(r))
        print(sep("-"))

@dataclass
class PassResult:
    table: str
    status: str
    attempted: int
    affected: int
    db_rows: int
    missing_drug: int
    seconds: float

# ---------------------------
# Generic one-table pass runner (re-scans XML each time)
# handler(cur, drug_pk, drug_elem, nstag) -> (attempted, affected)
# ---------------------------
def run_table_pass(
    xml_path: str,
    table: str,
    handler: Callable[[Any, int, Any, Callable[[str], str]], Tuple[int, int]],
    ns: Optional[str],
    drug_pk_map: Dict[str, int],
    total_drugs_xml: int,
) -> PassResult:
    nstag = nstag_factory(ns)

    # skip / wipe logic
    conn_check = get_conn()
    current_rows = count_table(conn_check, table)
    conn_check.close()

    if table in FORCE_RELOAD_TABLES or table in WIPE_BEFORE_LOAD_TABLES:
        wipe_table_delete(table)
        current_rows = 0

    if SKIP_IF_NONEMPTY and current_rows > 0 and table not in FORCE_RELOAD_TABLES:
        return PassResult(table=table, status="SKIPPED (already has rows)", attempted=0, affected=0,
                          db_rows=current_rows, missing_drug=0, seconds=0.0)

    conn = get_conn()
    cur = conn.cursor()

    attempted_total = 0
    affected_total = 0
    missing_drug = 0
    processed = 0
    t0 = time.time()

    pbar = tqdm(total=total_drugs_xml, unit="drug", dynamic_ncols=True, desc=table)

    try:
        context = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, elem in context:
            if strip_ns(elem.tag) != "drug":
                continue

            primary_id = pick_primary_id(elem, nstag)
            if not primary_id:
                processed += 1
                pbar.update(1)
                cleanup_lxml(elem)
                continue

            drug_pk = drug_pk_map.get(primary_id)
            if drug_pk is None:
                missing_drug += 1
                processed += 1
                pbar.update(1)
                cleanup_lxml(elem)
                continue

            try:
                a, aff = handler(cur, drug_pk, elem, nstag)
                attempted_total += a
                affected_total += aff

                processed += 1
                pbar.update(1)

                if processed % COMMIT_EVERY == 0:
                    conn.commit()

                if SHOW_EVERY and processed % SHOW_EVERY == 0:
                    elapsed = time.time() - t0
                    rate = processed / elapsed if elapsed > 0 else 0.0
                    print(f"[{table}] drugs={processed:,}/{total_drugs_xml:,} attempted={attempted_total:,} "
                          f"affected={affected_total:,} missing_drug={missing_drug:,} rate={rate:.2f}/sec")

            except Exception as e:
                conn.rollback()
                print(f"[{table}] ERROR drug_pk={drug_pk} ({primary_id}): {e}")

            finally:
                cleanup_lxml(elem)

        conn.commit()

    finally:
        try: pbar.close()
        except: pass
        cur.close()
        conn.close()

    seconds = time.time() - t0
    conn2 = get_conn()
    db_rows = count_table(conn2, table)
    conn2.close()

    return PassResult(table=table, status="LOADED", attempted=attempted_total, affected=affected_total,
                      db_rows=db_rows, missing_drug=missing_drug, seconds=seconds)

# =====================================================================================
# Stage-2 handlers
# =====================================================================================

SQL_ID_MAP = """
INSERT INTO drugbank_id_map (drugbank_id, drug_pk, is_primary)
VALUES (%s,%s,%s)
ON DUPLICATE KEY UPDATE drug_pk=VALUES(drug_pk), is_primary=VALUES(is_primary)
"""
def h_drugbank_id_map(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    for ide in drug_elem.findall(nstag("drugbank-id")):
        did = text(ide)
        if not did:
            continue
        prim = ide.get("primary")
        is_primary = False
        if prim is not None:
            v = prim.strip().lower()
            is_primary = v not in ("false","0","no","n","")
        cur.execute(SQL_ID_MAP, (did, drug_pk, bool(is_primary)))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_GROUP = "INSERT IGNORE INTO drug_group (drug_pk, group_name) VALUES (%s,%s)"
def h_drug_group(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    groups = drug_elem.find(nstag("groups"))
    if groups is None:
        return 0, 0
    for g in groups.findall(nstag("group")):
        gv = text(g)
        if gv:
            cur.execute(SQL_GROUP, (drug_pk, gv))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_CLASS = """
INSERT INTO drug_classification
  (drug_pk, description, direct_parent, kingdom, superclass, class_name, subclass_name)
VALUES (%s,%s,%s,%s,%s,%s,%s)
ON DUPLICATE KEY UPDATE
  description=VALUES(description),
  direct_parent=VALUES(direct_parent),
  kingdom=VALUES(kingdom),
  superclass=VALUES(superclass),
  class_name=VALUES(class_name),
  subclass_name=VALUES(subclass_name)
"""
def h_drug_classification(cur, drug_pk, drug_elem, nstag):
    c = drug_elem.find(nstag("classification"))
    if c is None:
        return 0, 0
    def f(tag): return text(c.find(nstag(tag)))
    cur.execute(SQL_CLASS, (drug_pk, f("description"), f("direct-parent"), f("kingdom"),
                            f("superclass"), f("class"), f("subclass")))
    return 1, max(cur.rowcount, 0)

# These will stay 0 if XML lacks those sections (your XML does)
SQL_ALT_PARENT = "INSERT IGNORE INTO drug_classification_alternative_parent (drug_pk, value) VALUES (%s,%s)"
def h_alt_parent(cur, drug_pk, drug_elem, nstag):
    c = drug_elem.find(nstag("classification"))
    if c is None:
        return 0, 0
    aps = c.find(nstag("alternative-parents"))
    if aps is None:
        return 0, 0
    attempted = affected = 0
    for ap in aps.findall(nstag("alternative-parent")):
        v = text(ap)
        if v:
            cur.execute(SQL_ALT_PARENT, (drug_pk, v))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_SUB = "INSERT IGNORE INTO drug_classification_substituent (drug_pk, value) VALUES (%s,%s)"
def h_substituent(cur, drug_pk, drug_elem, nstag):
    c = drug_elem.find(nstag("classification"))
    if c is None:
        return 0, 0
    subs = c.find(nstag("substituents"))
    if subs is None:
        return 0, 0
    attempted = affected = 0
    for s in subs.findall(nstag("substituent")):
        v = text(s)
        if v:
            cur.execute(SQL_SUB, (drug_pk, v))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_SYNONYM = "INSERT IGNORE INTO drug_synonym (drug_pk, synonym, language, coder) VALUES (%s,%s,%s,%s)"
def h_drug_synonym(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    syns = drug_elem.find(nstag("synonyms"))
    if syns is None:
        return 0, 0
    for s in syns.findall(nstag("synonym")):
        sv = text(s)
        if sv:
            cur.execute(SQL_SYNONYM, (drug_pk, sv, s.get("language"), s.get("coder")))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_SALT = "INSERT INTO drug_salt (drug_pk, name, unii, cas_number, inchikey, average_mass, monoisotopic_mass) VALUES (%s,%s,%s,%s,%s,%s,%s)"
def h_drug_salt(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    salts = drug_elem.find(nstag("salts"))
    if salts is None:
        return 0, 0
    for salt in salts.findall(nstag("salt")):
        name = text(salt.find(nstag("name")))
        unii = text(salt.find(nstag("unii")))
        cas = text(salt.find(nstag("cas-number")))
        inchikey = text(salt.find(nstag("inchikey")))
        avg = to_float(text(salt.find(nstag("average-mass"))))
        mono = to_float(text(salt.find(nstag("monoisotopic-mass"))))
        cur.execute(SQL_SALT, (drug_pk, name, unii, cas, inchikey, avg, mono))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_PRODUCT = """
INSERT INTO product (
  drug_pk, name, labeller, ndc_id, ndc_product_code, dpd_id, ema_product_code,
  started_marketing_on, ended_marketing_on, dosage_form, strength, route,
  fda_application_number, generic, over_the_counter, approved, country, source
)
VALUES (%s,%s,%s,%s,%s,%s,%s,
        %s,%s,%s,%s,%s,
        %s,%s,%s,%s,%s,%s)
"""
def h_product(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    products = drug_elem.find(nstag("products"))
    if products is None:
        return 0, 0
    for p in products.findall(nstag("product")):
        def f(tag): return text(p.find(nstag(tag)))
        generic = (f("generic") or "").strip().lower() in ("true","1","yes")
        otc = (f("over-the-counter") or "").strip().lower() in ("true","1","yes")
        approved = (f("approved") or "").strip().lower() in ("true","1","yes")
        cur.execute(SQL_PRODUCT, (
            drug_pk, f("name"), f("labeller"), f("ndc-id"), f("ndc-product-code"), f("dpd-id"), f("ema-product-code"),
            f("started-marketing-on"), f("ended-marketing-on"), f("dosage-form"), f("strength"), f("route"),
            f("fda-application-number"), generic, otc, approved, f("country"), f("source")
        ))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_PACKAGER = """
INSERT INTO packager (drug_pk, name, url)
VALUES (%s,%s,%s)
ON DUPLICATE KEY UPDATE url=VALUES(url)
"""
def h_packager(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    packagers = drug_elem.find(nstag("packagers"))
    if packagers is None:
        return 0, 0
    for p in packagers.findall(nstag("packager")):
        name = text(p.find(nstag("name")))
        url = text(p.find(nstag("url")))
        if name:
            cur.execute(SQL_PACKAGER, (drug_pk, name, url))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_MANUFACTURER = """
INSERT INTO manufacturer (drug_pk, name, generic, url)
VALUES (%s,%s,%s,%s)
ON DUPLICATE KEY UPDATE generic=VALUES(generic), url=VALUES(url)
"""
def h_manufacturer(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    mfgs = drug_elem.find(nstag("manufacturers"))
    if mfgs is None:
        return 0, 0
    for m in mfgs.findall(nstag("manufacturer")):
        name = text(m)
        if not name:
            continue
        generic_attr = (m.get("generic") or "").strip().lower()
        generic = generic_attr in ("true","1","yes")
        url = m.get("url")
        cur.execute(SQL_MANUFACTURER, (drug_pk, name, generic, url))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_CATEGORY = "INSERT IGNORE INTO drug_category (drug_pk, category, mesh_id) VALUES (%s,%s,%s)"
def h_drug_category(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    cats = drug_elem.find(nstag("categories"))
    if cats is None:
        return 0, 0
    for c in cats.findall(nstag("category")):
        cat = text(c.find(nstag("category")))
        mesh = text(c.find(nstag("mesh-id")))
        if cat:
            cur.execute(SQL_CATEGORY, (drug_pk, cat, mesh))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_AO = "INSERT IGNORE INTO drug_affected_organism (drug_pk, organism) VALUES (%s,%s)"
def h_affected_organism(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    aos = drug_elem.find(nstag("affected-organisms"))
    if aos is None:
        return 0, 0
    for ao in aos.findall(nstag("affected-organism")):
        v = text(ao)
        if v:
            cur.execute(SQL_AO, (drug_pk, v))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_DOSAGE = "INSERT IGNORE INTO drug_dosage (drug_pk, form, route, strength) VALUES (%s,%s,%s,%s)"
def h_drug_dosage(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    dosages = drug_elem.find(nstag("dosages"))
    if dosages is None:
        return 0, 0
    for d in dosages.findall(nstag("dosage")):
        cur.execute(SQL_DOSAGE, (drug_pk, text(d.find(nstag("form"))), text(d.find(nstag("route"))), text(d.find(nstag("strength")))))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

# --- FIXED AHFS: flexible extraction ---
SQL_AHFS = "INSERT IGNORE INTO drug_ahfs_code (drug_pk, ahfs_code) VALUES (%s,%s)"
def h_drug_ahfs(cur, drug_pk, drug_elem, nstag):
    ahfs = drug_elem.find(nstag("ahfs-codes"))
    if ahfs is None:
        return 0, 0

    attempted = affected = 0
    # accept any child element under ahfs-codes; many versions use <ahfs-code>, some vary
    for child in list(ahfs):
        v = text(child)
        if v:
            cur.execute(SQL_AHFS, (drug_pk, v))
            attempted += 1
            affected += max(cur.rowcount, 0)

    # if no direct children text matched, fallback to scanning descendants
    if attempted == 0:
        for child in ahfs.iterdescendants():
            if strip_ns(child.tag).lower().endswith("ahfs-code") or "ahfs" in strip_ns(child.tag).lower():
                v = text(child)
                if v:
                    cur.execute(SQL_AHFS, (drug_pk, v))
                    attempted += 1
                    affected += max(cur.rowcount, 0)

    return attempted, affected

SQL_PDB = "INSERT IGNORE INTO drug_pdb_entry (drug_pk, pdb_entry) VALUES (%s,%s)"
def h_drug_pdb(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    pdbs = drug_elem.find(nstag("pdb-entries"))
    if pdbs is None:
        return 0, 0
    for p in pdbs.findall(nstag("pdb-entry")):
        v = text(p)
        if v:
            cur.execute(SQL_PDB, (drug_pk, v))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_ATC_CODE = "INSERT IGNORE INTO drug_atc_code (drug_pk, atc_code) VALUES (%s,%s)"
def h_drug_atc_code(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    atcs = drug_elem.find(nstag("atc-codes"))
    if atcs is None:
        return 0, 0
    for atc in atcs.findall(nstag("atc-code")):
        code = atc.get("code")
        if code:
            cur.execute(SQL_ATC_CODE, (drug_pk, code))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_ATC_LEVEL = """
INSERT INTO drug_atc_level (drug_pk, atc_code, level_no, level_text, level_code_attr)
VALUES (%s,%s,%s,%s,%s)
ON DUPLICATE KEY UPDATE level_text=VALUES(level_text), level_code_attr=VALUES(level_code_attr)
"""
def h_drug_atc_level(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    atcs = drug_elem.find(nstag("atc-codes"))
    if atcs is None:
        return 0, 0
    for atc in atcs.findall(nstag("atc-code")):
        code = atc.get("code")
        if not code:
            continue
        for i, lvl in enumerate(atc.findall(nstag("level")), start=1):
            cur.execute(SQL_ATC_LEVEL, (drug_pk, code, i, text(lvl), lvl.get("code")))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_PRICE = "INSERT INTO drug_price (drug_pk, description, cost_value, cost_currency, unit) VALUES (%s,%s,%s,%s,%s)"
def h_drug_price(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    prices = drug_elem.find(nstag("prices"))
    if prices is None:
        return 0, 0
    for price in prices.findall(nstag("price")):
        desc = text(price.find(nstag("description")))
        cost_elem = price.find(nstag("cost"))
        cost_value = text(cost_elem)
        cost_currency = cost_elem.get("currency") if cost_elem is not None else None
        unit = text(price.find(nstag("unit")))
        cur.execute(SQL_PRICE, (drug_pk, desc, cost_value, cost_currency, unit))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_PATENT = "INSERT INTO drug_patent (drug_pk, number, country, approved, expires, pediatric_extension) VALUES (%s,%s,%s,%s,%s,%s)"
def h_drug_patent(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    pats = drug_elem.find(nstag("patents"))
    if pats is None:
        return 0, 0
    for pat in pats.findall(nstag("patent")):
        ped = (text(pat.find(nstag("pediatric-extension"))) or "").strip().lower() in ("true","1","yes")
        cur.execute(SQL_PATENT, (drug_pk, text(pat.find(nstag("number"))), text(pat.find(nstag("country"))),
                                text(pat.find(nstag("approved"))), text(pat.find(nstag("expires"))), ped))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_FOOD = "INSERT IGNORE INTO drug_food_interaction (drug_pk, interaction_text) VALUES (%s,%s)"
def h_drug_food(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    food = drug_elem.find(nstag("food-interactions"))
    if food is None:
        return 0, 0
    for fi in food.findall(nstag("food-interaction")):
        v = text(fi)
        if v:
            cur.execute(SQL_FOOD, (drug_pk, v))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_DRUG_INTERACTION = """
INSERT INTO drug_interaction (drug_pk, interacting_drugbank_id, interacting_name, description)
VALUES (%s,%s,%s,%s)
ON DUPLICATE KEY UPDATE interacting_name=VALUES(interacting_name), description=VALUES(description)
"""
def h_drug_interaction(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    dis = drug_elem.find(nstag("drug-interactions"))
    if dis is None:
        return 0, 0
    for di in dis.findall(nstag("drug-interaction")):
        other_id = text(di.find(nstag("drugbank-id")))
        if not other_id:
            continue
        cur.execute(SQL_DRUG_INTERACTION, (drug_pk, other_id, text(di.find(nstag("name"))), text(di.find(nstag("description")))))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_EXT_ID = "INSERT IGNORE INTO drug_external_identifier (drug_pk, resource, identifier) VALUES (%s,%s,%s)"
def h_ext_id(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    ex_ids = drug_elem.find(nstag("external-identifiers"))
    if ex_ids is None:
        return 0, 0
    for exi in ex_ids.findall(nstag("external-identifier")):
        resource = text(exi.find(nstag("resource")))
        identifier = text(exi.find(nstag("identifier")))
        if resource and identifier:
            cur.execute(SQL_EXT_ID, (drug_pk, resource, identifier))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_EXT_LINK = "INSERT IGNORE INTO drug_external_link (drug_pk, resource, url) VALUES (%s,%s,%s)"
def h_ext_link(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    ex_links = drug_elem.find(nstag("external-links"))
    if ex_links is None:
        return 0, 0
    for exl in ex_links.findall(nstag("external-link")):
        resource = text(exl.find(nstag("resource")))
        url = text(exl.find(nstag("url")))
        if resource and url:
            cur.execute(SQL_EXT_LINK, (drug_pk, resource, url))
            attempted += 1
            affected += max(cur.rowcount, 0)
    return attempted, affected

SQL_PROPERTY = "INSERT INTO drug_property (drug_pk, property_type, kind, value, source) VALUES (%s,%s,%s,%s,%s)"
def h_drug_property(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    def insert_props(container_tag: str, prop_type: str):
        nonlocal attempted, affected
        cont = drug_elem.find(nstag(container_tag))
        if cont is None:
            return
        for prop in cont.findall(nstag("property")):
            cur.execute(SQL_PROPERTY, (drug_pk, prop_type,
                                       text(prop.find(nstag("kind"))),
                                       text(prop.find(nstag("value"))),
                                       text(prop.find(nstag("source")))))
            attempted += 1
            affected += max(cur.rowcount, 0)
    insert_props("calculated-properties", "calculated")
    insert_props("experimental-properties", "experimental")
    return attempted, affected

SQL_REACTION = "INSERT INTO reaction (drug_pk, sequence_text) VALUES (%s,%s)"
def h_reaction(cur, drug_pk, drug_elem, nstag):
    attempted = affected = 0
    reactions = drug_elem.find(nstag("reactions"))
    if reactions is None:
        return 0, 0
    for rxn in reactions.findall(nstag("reaction")):
        cur.execute(SQL_REACTION, (drug_pk, text(rxn.find(nstag("sequence")))))
        attempted += 1
        affected += max(cur.rowcount, 0)
    return attempted, affected

# ---------------------------
# RUN Stage 2
# ---------------------------
print("Stage 2 pre-pass: counting total <drug> in XML...")
total_drugs_xml = count_total_drugs_with_progress(XML_PATH)
print(f"Total drugs in XML: {total_drugs_xml:,}")

ns = detect_namespace(XML_PATH)
conn0 = get_conn()
drug_pk_map = load_drug_pk_map(conn0)
conn0.close()

print(f"Drugs in DB (Stage 1): {len(drug_pk_map):,}")
if len(drug_pk_map) != total_drugs_xml:
    print("WARNING: drug table count != XML count. Missing drugs will be reported per table.")

tasks = [
    ("drugbank_id_map", h_drugbank_id_map),
    ("drug_group", h_drug_group),
    ("drug_classification", h_drug_classification),
    ("drug_classification_alternative_parent", h_alt_parent),
    ("drug_classification_substituent", h_substituent),
    ("drug_synonym", h_drug_synonym),
    ("drug_salt", h_drug_salt),
    ("product", h_product),
    ("packager", h_packager),
    ("manufacturer", h_manufacturer),
    ("drug_category", h_drug_category),
    ("drug_affected_organism", h_affected_organism),
    ("drug_dosage", h_drug_dosage),
    ("drug_ahfs_code", h_drug_ahfs),     # FIXED extractor
    ("drug_pdb_entry", h_drug_pdb),
    ("drug_atc_code", h_drug_atc_code),
    ("drug_atc_level", h_drug_atc_level),
    ("drug_price", h_drug_price),
    ("drug_patent", h_drug_patent),
    ("drug_food_interaction", h_drug_food),
    ("drug_interaction", h_drug_interaction),
    ("drug_external_identifier", h_ext_id),
    ("drug_external_link", h_ext_link),
    ("drug_property", h_drug_property),
    ("reaction", h_reaction),
]

results: List[PassResult] = []
for table, handler in tasks:
    results.append(run_table_pass(XML_PATH, table, handler, ns, drug_pk_map, total_drugs_xml))

print("\n=== Stage 2 Summary (Corrected) ===")
rows = []
for r in results:
    rows.append([r.table, r.status, f"{r.attempted:,}", f"{r.affected:,}", f"{r.db_rows:,}", f"{r.missing_drug:,}", f"{r.seconds:.1f}s"])
print_table(rows, headers=["Table","Status","Attempted","Affected(rowcount)","Rows in DB","Missing drug","Time"])

Stage 2 pre-pass: counting total <drug> in XML...


Pre-pass: counting <drug>: 0drug [00:00, ?drug/s]

KeyboardInterrupt: 

<small>

### Stage 3 — Children of `drug_salt` (drug_salt_drugbank_id)

This stage inserts into:

- **`drug_salt_drugbank_id`** *(Stage 3)* — links **salt** → **DrugBank salt IDs**  
  Foreign key: `drug_salt_drugbank_id.salt_pk` → `drug_salt.salt_pk`

---

#### Prerequisites
Before running Stage 3, you must have:

- **Stage 1:** `drug` loaded (so `drug_pk` exists)
- **Stage 2:** `drug_salt` loaded (so `salt_pk` exists)

---

#### What the Stage 3 cell does
- Streams the XML once
- For each `<drug>`:
  - Resolves `drug_pk` (from `primary_drugbank_id`)
  - Reads `<salts>/<salt>` entries
  - Matches each XML `<salt>` to the corresponding DB row in `drug_salt`
  - Inserts each `<salt>/<drugbank-id>` into `drug_salt_drugbank_id`

---

#### Matching logic (XML salt → DB salt)
To find the correct `salt_pk`, it matches the salt by these fields:

- `name`
- `unii`
- `cas-number`
- `inchikey`
- `average-mass`
- `monoisotopic-mass`

Because floating values can differ slightly, the code normalizes/rounds numeric values to reduce mismatch risk.

If a salt cannot be matched, it is counted as **unmatched_salts**.

---

#### Insert behavior
- Uses **`INSERT IGNORE`** into `drug_salt_drugbank_id`  
  → safe to re-run without duplicating rows.

Optional settings:
- **Skip if non-empty:** if `drug_salt_drugbank_id` already has rows, the stage can be skipped.
- **Wipe before load:** can `DELETE FROM drug_salt_drugbank_id` (safe; it’s a child table).

---

#### Output / reporting
The stage prints:
- Progress bar over total drugs
- Periodic status (attempted inserts, inserted rows, matched/unmatched salts)
- Final summary table:
  - attempted vs inserted rows
  - rows in DB
  - matched/unmatched salt counts
  - errors and runtime

</small>


In [26]:
import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple, List

from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# ---------------------------
# CONFIG (EDIT THESE)
# ---------------------------
XML_PATH = r"db/drugbank.xml"   # <-- change

MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "root"
MYSQL_PASSWORD = ""
MYSQL_DB = "drugbank"

COMMIT_EVERY = 2000
SHOW_EVERY = 10000

# If the table already has rows, skip the whole Stage-3 import
SKIP_IF_NONEMPTY = True

# If True, wipes drug_salt_drugbank_id before loading (safe; it's the child table)
WIPE_BEFORE_LOAD = False
# ---------------------------


def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DB,
        autocommit=False,
    )

def count_table(conn, table: str) -> int:
    cur = conn.cursor()
    cur.execute(f"SELECT COUNT(*) FROM `{table}`")
    n = int(cur.fetchone()[0])
    cur.close()
    return n

def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def detect_namespace(xml_path: str) -> Optional[str]:
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns: Optional[str]):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem) -> Optional[str]:
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def bool_from_text(s: Optional[str]) -> Optional[bool]:
    if s is None:
        return None
    v = s.strip().lower()
    if v in ("true", "1", "yes"): return True
    if v in ("false", "0", "no"): return False
    return None

def to_float(s: Optional[str]) -> Optional[float]:
    if s is None:
        return None
    try:
        return float(s)
    except:
        return None

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def get_primary_drugbank_id(drug_elem, nstag) -> Optional[str]:
    ids = drug_elem.findall(nstag("drugbank-id"))
    fallback = None
    primary = None
    for ide in ids:
        did = text(ide)
        if not did:
            continue
        if fallback is None:
            fallback = did
        if bool_from_text(ide.get("primary")):
            primary = did
            break
    return primary or fallback

def load_drug_pk_map(conn) -> Dict[str, int]:
    cur = conn.cursor()
    cur.execute("SELECT primary_drugbank_id, drug_pk FROM drug")
    mp = {row[0]: int(row[1]) for row in cur.fetchall()}
    cur.close()
    return mp

def norm_str(s: Optional[str]) -> Optional[str]:
    return s.strip() if isinstance(s, str) and s.strip() else None

def norm_float(v: Optional[float]) -> Optional[float]:
    if v is None:
        return None
    # rounding avoids float representation mismatch between XML parsing and DB double
    return round(float(v), 6)

def salt_key(name, unii, cas, inchikey, avg, mono) -> Tuple[Optional[str], Optional[str], Optional[str], Optional[str], Optional[float], Optional[float]]:
    return (
        norm_str(name),
        norm_str(unii),
        norm_str(cas),
        norm_str(inchikey),
        norm_float(avg),
        norm_float(mono),
    )

# SQL
SQL_LOAD_SALTS_FOR_DRUG = """
SELECT salt_pk, name, unii, cas_number, inchikey, average_mass, monoisotopic_mass
FROM drug_salt
WHERE drug_pk = %s
ORDER BY salt_pk
"""

SQL_INSERT_SALT_ID = """
INSERT IGNORE INTO drug_salt_drugbank_id (salt_pk, drugbank_id, is_primary)
VALUES (%s, %s, %s)
"""

@dataclass
class Stage3Result:
    table: str
    status: str
    attempted: int
    inserted: int
    rows_in_db: int
    processed_drugs: int
    matched_salts: int
    unmatched_salts: int
    missing_drug: int
    errors: int
    seconds: float

def import_stage3_drug_salt_drugbank_id(xml_path: str) -> Stage3Result:
    ns = detect_namespace(xml_path)
    nstag = nstag_factory(ns)

    # Pre-check
    conn_check = get_conn()
    existing = count_table(conn_check, "drug_salt_drugbank_id")
    conn_check.close()

    if SKIP_IF_NONEMPTY and existing > 0 and not WIPE_BEFORE_LOAD:
        return Stage3Result(
            table="drug_salt_drugbank_id",
            status="SKIPPED (already has rows)",
            attempted=0,
            inserted=0,
            rows_in_db=existing,
            processed_drugs=0,
            matched_salts=0,
            unmatched_salts=0,
            missing_drug=0,
            errors=0,
            seconds=0.0,
        )

    if WIPE_BEFORE_LOAD:
        conn_wipe = get_conn()
        cur_wipe = conn_wipe.cursor()
        cur_wipe.execute("DELETE FROM drug_salt_drugbank_id")
        conn_wipe.commit()
        cur_wipe.close()
        conn_wipe.close()

    # Main run
    conn = get_conn()
    cur = conn.cursor()

    drug_pk_map = load_drug_pk_map(conn)
    total_drugs = len(drug_pk_map)
    if total_drugs == 0:
        cur.close()
        conn.close()
        raise RuntimeError("drug table is empty. Run Stage 1 first.")

    attempted = 0
    inserted = 0
    processed_drugs = 0
    matched_salts = 0
    unmatched_salts = 0
    missing_drug = 0
    errors = 0

    # cache: drug_pk -> dict[key] = list[salt_pk]
    salts_cache: Dict[int, Dict[Tuple, List[int]]] = {}

    def get_salts_map_for_drug(drug_pk: int) -> Dict[Tuple, List[int]]:
        if drug_pk in salts_cache:
            return salts_cache[drug_pk]

        cur.execute(SQL_LOAD_SALTS_FOR_DRUG, (drug_pk,))
        rows = cur.fetchall()

        m: Dict[Tuple, List[int]] = {}
        for (salt_pk, name, unii, cas, inchikey, avg, mono) in rows:
            k = salt_key(name, unii, cas, inchikey, avg, mono)
            m.setdefault(k, []).append(int(salt_pk))

        # cache (you can cap cache size if memory is a concern; usually OK)
        salts_cache[drug_pk] = m
        return m

    t0 = time.time()
    pbar = tqdm(total=total_drugs, unit="drug", dynamic_ncols=True, desc="drug_salt_drugbank_id (Stage 3)")

    try:
        context = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, elem in context:
            if strip_ns(elem.tag) != "drug":
                continue

            primary_id = get_primary_drugbank_id(elem, nstag)
            if not primary_id:
                cleanup_lxml(elem)
                continue

            drug_pk = drug_pk_map.get(primary_id)
            if drug_pk is None:
                missing_drug += 1
                cleanup_lxml(elem)
                continue

            try:
                salts = elem.find(nstag("salts"))
                if salts is not None:
                    salts_map = get_salts_map_for_drug(drug_pk)

                    # For each XML salt, match to a DB salt_pk (and pop so we don't reuse)
                    for salt in salts.findall(nstag("salt")):
                        s_name = text(salt.find(nstag("name")))
                        s_unii = text(salt.find(nstag("unii")))
                        s_cas = text(salt.find(nstag("cas-number")))
                        s_inchikey = text(salt.find(nstag("inchikey")))
                        s_avg = to_float(text(salt.find(nstag("average-mass"))))
                        s_mono = to_float(text(salt.find(nstag("monoisotopic-mass"))))

                        k = salt_key(s_name, s_unii, s_cas, s_inchikey, s_avg, s_mono)
                        lst = salts_map.get(k, [])
                        if not lst:
                            unmatched_salts += 1
                            continue

                        salt_pk = lst.pop(0)
                        matched_salts += 1

                        for sid in salt.findall(nstag("drugbank-id")):
                            did = text(sid)
                            if not did:
                                continue
                            is_primary = bool_from_text(sid.get("primary")) or False
                            cur.execute(SQL_INSERT_SALT_ID, (salt_pk, did, bool(is_primary)))
                            attempted += 1
                            inserted += max(cur.rowcount, 0)

                processed_drugs += 1
                pbar.update(1)

                if processed_drugs % COMMIT_EVERY == 0:
                    conn.commit()

                if SHOW_EVERY and processed_drugs % SHOW_EVERY == 0:
                    elapsed = time.time() - t0
                    rate = processed_drugs / elapsed if elapsed > 0 else 0.0
                    print(f"[Stage3] drugs={processed_drugs:,} attempted={attempted:,} inserted={inserted:,} matched_salts={matched_salts:,} unmatched_salts={unmatched_salts:,} rate={rate:.2f}/sec")

            except Exception as e:
                errors += 1
                conn.rollback()
                print(f"[Stage3] ERROR drug_pk={drug_pk} ({primary_id}): {e}")

            finally:
                cleanup_lxml(elem)

        conn.commit()

    finally:
        try:
            pbar.close()
        except Exception:
            pass
        cur.close()
        conn.close()

    seconds = time.time() - t0
    conn2 = get_conn()
    rows_in_db = count_table(conn2, "drug_salt_drugbank_id")
    conn2.close()

    return Stage3Result(
        table="drug_salt_drugbank_id",
        status="LOADED",
        attempted=attempted,
        inserted=inserted,
        rows_in_db=rows_in_db,
        processed_drugs=processed_drugs,
        matched_salts=matched_salts,
        unmatched_salts=unmatched_salts,
        missing_drug=missing_drug,
        errors=errors,
        seconds=seconds,
    )

def print_result(r: Stage3Result):
    headers = ["Table", "Status", "Attempted", "Inserted", "Rows in DB", "Processed drugs", "Matched salts", "Unmatched salts", "Missing drug", "Errors", "Time"]
    rows = [[
        r.table,
        r.status,
        f"{r.attempted:,}",
        f"{r.inserted:,}",
        f"{r.rows_in_db:,}",
        f"{r.processed_drugs:,}",
        f"{r.matched_salts:,}",
        f"{r.unmatched_salts:,}",
        f"{r.missing_drug:,}",
        f"{r.errors:,}",
        f"{r.seconds:.1f}s",
    ]]
    # simple fixed-width print
    widths = [len(h) for h in headers]
    for row in rows:
        for i, v in enumerate(row):
            widths[i] = max(widths[i], len(str(v)))
    def sep(ch="-"):
        return "+".join(ch * (w + 2) for w in widths)
    def render(row):
        return "|".join(" " + str(row[i]).ljust(widths[i]) + " " for i in range(len(headers)))
    print(sep("-"))
    print(render(headers))
    print(sep("="))
    print(render(rows[0]))
    print(sep("-"))

# ---- RUN STAGE 3 ----
result = import_stage3_drug_salt_drugbank_id(XML_PATH)
print_result(result)


drug_salt_drugbank_id (Stage 3):   0%|                                                     | 0/17430 [00:00<?,…

[Stage3] drugs=10,000 attempted=50 inserted=50 matched_salts=50 unmatched_salts=0 rate=3804.31/sec
[Stage3] drugs=20,000 attempted=369 inserted=369 matched_salts=369 unmatched_salts=0 rate=1510.30/sec
[Stage3] drugs=30,000 attempted=1,158 inserted=1,158 matched_salts=1,158 unmatched_salts=0 rate=482.46/sec
[Stage3] drugs=40,000 attempted=1,187 inserted=1,187 matched_salts=1,187 unmatched_salts=0 rate=613.48/sec
[Stage3] drugs=50,000 attempted=1,221 inserted=1,221 matched_salts=1,221 unmatched_salts=0 rate=736.74/sec
[Stage3] drugs=60,000 attempted=1,360 inserted=1,360 matched_salts=1,360 unmatched_salts=0 rate=817.61/sec
[Stage3] drugs=70,000 attempted=2,659 inserted=2,659 matched_salts=2,659 unmatched_salts=0 rate=532.89/sec
-----------------------+--------+-----------+----------+------------+-----------------+---------------+-----------------+--------------+--------+--------
 Table                 | Status | Attempted | Inserted | Rows in DB | Processed drugs | Matched salts | Unmatc

<small>

### Stage 4 — References link tables (drug_reference, interactant_reference)

This stage builds the **reference linking tables** using foreign keys:

- **`reference_item`** *(Stage 0 prerequisite)* — inserted **as needed**
- **`drug_reference`** *(Stage 4)* — links `drug` → `reference_item`
- **`interactant_reference`** *(Stage 4)* — links `interactant` → `reference_item`  
  *(Only runs if the `interactant` table is already loaded.)*

---

#### What this Stage 4 cell does

**1) Count total drugs first (pre-pass)**  
- Performs a **streaming pre-pass** over the XML and counts total `<drug>` elements.  
- Uses this number for a **true progress bar** in the main pass.

**2) Main processing pass (with progress)**  
- Streams the XML again and processes each `<drug>`:
  - Inserts rows into **`reference_item`** if needed
  - Inserts rows into **`drug_reference`**
  - Optionally inserts rows into **`interactant_reference`**

---

#### Matching & deduplication logic

**A) `reference_item` dedupe**  
`reference_item` has **no natural unique key**, so we deduplicate using:

- Python cache + DB lookup using a composite key:  
  `(ref_type, ref_id, pubmed_id, citation, isbn, title, url)`

**B) `drug_reference` insertion**  
- Uses `INSERT IGNORE` into `(drug_pk, reference_pk)`.

**C) `interactant_reference` insertion (only if interactants exist)**  
Requires `interactant` rows already loaded (Stage 7). We locate `interactant_pk` by matching:

- `drug_pk + kind + position + interactant_id + name + organism + known_action`

If not found, the link is skipped and counted as **missing_interactant**.

---

#### Notes
- This implementation is **correctness-first**, not speed-first.
- It runs the XML twice (one pass for counting, one pass for insertion).

</small>


In [16]:
# =============================================================================
# STAGE 4 (Jupyter single cell)
# Builds:
#   - reference_item  (Stage 0 prerequisite; inserted here as needed)
#   - drug_reference  (Stage 4)
#   - interactant_reference (Stage 4, ONLY if interactant table is already loaded)
#
# As requested:
#   1) FIRST counts total <drug> in the XML (streaming pre-pass)
#   2) THEN processes the XML with a progress bar based on that total
#
# Matching logic:
# - reference_item has no natural unique key, so we dedupe in Python + DB lookup using a composite key:
#     (ref_type, ref_id, pubmed_id, citation, isbn, title, url)
#   We keep an in-memory cache for speed.
#
# - drug_reference uses INSERT IGNORE (PK: drug_pk, reference_pk)
#
# - interactant_reference requires interactant rows exist already.
#   We locate interactant_pk by matching:
#     drug_pk + kind + position + interactant_id + name + organism + known_action
#   If not found, we skip and count "missing_interactant".
#
# NOTE: This is correctness-first, not speed-first.
# =============================================================================

import time
from dataclasses import dataclass
from typing import Optional, Dict, Tuple, Any, List

from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# ---------------------------
# CONFIG (EDIT THESE)
# ---------------------------
XML_PATH = r"db/drugbank.xml"   # <-- change

MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "root"
MYSQL_PASSWORD = ""
MYSQL_DB = "drugbank"

COMMIT_EVERY = 1000
SHOW_EVERY = 10000

# If interactant table isn't loaded yet, set this to False for now
DO_INTERACTANT_REFERENCE = True
# ---------------------------


def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DB,
        autocommit=False,
    )

def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def detect_namespace(xml_path: str) -> Optional[str]:
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns: Optional[str]):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem) -> Optional[str]:
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def bool_from_text(s: Optional[str]) -> Optional[bool]:
    if s is None:
        return None
    v = s.strip().lower()
    if v in ("true", "1", "yes"): return True
    if v in ("false", "0", "no"): return False
    return None

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def get_primary_drugbank_id(drug_elem, nstag) -> Optional[str]:
    ids = drug_elem.findall(nstag("drugbank-id"))
    fallback = None
    primary = None
    for ide in ids:
        did = text(ide)
        if not did:
            continue
        if fallback is None:
            fallback = did
        if bool_from_text(ide.get("primary")):
            primary = did
            break
    return primary or fallback

def load_drug_pk_map(conn) -> Dict[str, int]:
    cur = conn.cursor()
    cur.execute("SELECT primary_drugbank_id, drug_pk FROM drug")
    mp = {row[0]: int(row[1]) for row in cur.fetchall()}
    cur.close()
    return mp

def table_exists(conn, table: str) -> bool:
    cur = conn.cursor()
    cur.execute(
        "SELECT COUNT(*) FROM information_schema.tables "
        "WHERE table_schema=%s AND table_name=%s",
        (MYSQL_DB, table),
    )
    ok = int(cur.fetchone()[0]) == 1
    cur.close()
    return ok

def count_table(conn, table: str) -> int:
    cur = conn.cursor()
    cur.execute(f"SELECT COUNT(*) FROM `{table}`")
    n = int(cur.fetchone()[0])
    cur.close()
    return n

def count_total_drugs_in_xml(xml_path: str) -> int:
    total = 0
    ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
    for _, elem in ctx:
        if strip_ns(elem.tag) == "drug":
            total += 1
            cleanup_lxml(elem)
    return total

# ---------------------------
# SQL
# ---------------------------
SQL_REF_SELECT = """
SELECT reference_pk
FROM reference_item
WHERE ref_type <=> %s
  AND ref_id <=> %s
  AND pubmed_id <=> %s
  AND citation <=> %s
  AND isbn <=> %s
  AND title <=> %s
  AND url <=> %s
LIMIT 1
"""

SQL_REF_INSERT = """
INSERT INTO reference_item
  (ref_type, ref_id, pubmed_id, citation, isbn, title, url)
VALUES (%s,%s,%s,%s,%s,%s,%s)
"""

SQL_DRUG_REF_INSERT = "INSERT IGNORE INTO drug_reference (drug_pk, reference_pk) VALUES (%s,%s)"

SQL_INTERACTANT_FIND = """
SELECT interactant_pk
FROM interactant
WHERE drug_pk=%s
  AND kind <=> %s
  AND position <=> %s
  AND interactant_id <=> %s
  AND name <=> %s
  AND organism <=> %s
  AND known_action <=> %s
ORDER BY interactant_pk DESC
LIMIT 1
"""

SQL_INTERACTANT_REF_INSERT = "INSERT IGNORE INTO interactant_reference (interactant_pk, reference_pk) VALUES (%s,%s)"

# ---------------------------
# Reference parsing helpers
# ---------------------------
def parse_reference_list(ref_list_elem, nstag) -> List[Tuple[str, Optional[str], Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]]:
    """
    Returns list of tuples:
      (ref_type, ref_id, pubmed_id, citation, isbn, title, url)
    """
    out = []

    # articles/article => ref-id, pubmed-id, citation
    arts = ref_list_elem.find(nstag("articles"))
    if arts is not None:
        for a in arts.findall(nstag("article")):
            out.append((
                "article",
                text(a.find(nstag("ref-id"))),
                text(a.find(nstag("pubmed-id"))),
                text(a.find(nstag("citation"))),
                None,
                None,
                None
            ))

    # textbooks/textbook => ref-id, isbn, citation
    tbs = ref_list_elem.find(nstag("textbooks"))
    if tbs is not None:
        for t in tbs.findall(nstag("textbook")):
            out.append((
                "textbook",
                text(t.find(nstag("ref-id"))),
                None,
                text(t.find(nstag("citation"))),
                text(t.find(nstag("isbn"))),
                None,
                None
            ))

    # links/link => ref-id, title, url
    links = ref_list_elem.find(nstag("links"))
    if links is not None:
        for l in links.findall(nstag("link")):
            out.append((
                "link",
                text(l.find(nstag("ref-id"))),
                None,
                None,
                None,
                text(l.find(nstag("title"))),
                text(l.find(nstag("url")))
            ))

    # attachments/attachment => ref-id, title, url
    atts = ref_list_elem.find(nstag("attachments"))
    if atts is not None:
        for a in atts.findall(nstag("attachment")):
            out.append((
                "attachment",
                text(a.find(nstag("ref-id"))),
                None,
                None,
                None,
                text(a.find(nstag("title"))),
                text(a.find(nstag("url")))
            ))

    return out

def get_or_create_reference_pk(cur, cache: Dict[Tuple, int], ref_tuple: Tuple) -> int:
    """
    Dedupe reference_item using a cache + DB lookup.
    """
    if ref_tuple in cache:
        return cache[ref_tuple]

    (ref_type, ref_id, pubmed_id, citation, isbn, title, url) = ref_tuple

    cur.execute(SQL_REF_SELECT, (ref_type, ref_id, pubmed_id, citation, isbn, title, url))
    row = cur.fetchone()
    if row:
        pk = int(row[0])
        cache[ref_tuple] = pk
        return pk

    cur.execute(SQL_REF_INSERT, (ref_type, ref_id, pubmed_id, citation, isbn, title, url))
    pk = int(cur.lastrowid)
    cache[ref_tuple] = pk
    return pk

# ---------------------------
# MAIN Stage 4 Runner
# ---------------------------
@dataclass
class Stage4Result:
    reference_item_attempted: int
    reference_item_created: int
    drug_reference_attempted: int
    drug_reference_inserted: int
    interactant_reference_attempted: int
    interactant_reference_inserted: int
    missing_drug: int
    missing_interactant: int
    errors: int
    seconds: float

print("Stage 4: counting total <drug> in XML (pre-pass)...")
total_drugs_xml = count_total_drugs_in_xml(XML_PATH)
print(f"Total drugs in XML: {total_drugs_xml:,}")

ns = detect_namespace(XML_PATH)
nstag = nstag_factory(ns)

conn_check = get_conn()
drug_pk_map = load_drug_pk_map(conn_check)

if not drug_pk_map:
    conn_check.close()
    raise RuntimeError("drug table is empty. Run Stage 1 first.")

if DO_INTERACTANT_REFERENCE:
    if not table_exists(conn_check, "interactant") or not table_exists(conn_check, "interactant_reference"):
        print("WARNING: interactant/interactant_reference table not found. Disabling interactant_reference for this run.")
        DO_INTERACTANT_REFERENCE = False

conn_check.close()

conn = get_conn()
cur = conn.cursor()

ref_cache: Dict[Tuple, int] = {}

ref_attempted = 0
ref_created = 0
drugref_attempted = 0
drugref_inserted = 0
intref_attempted = 0
intref_inserted = 0

missing_drug = 0
missing_interactant = 0
errors = 0

processed_drugs = 0
t0 = time.time()

pbar = tqdm(total=total_drugs_xml, unit="drug", dynamic_ncols=True, desc="Stage4: drug_reference + interactant_reference")

try:
    ctx = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
    for _, elem in ctx:
        if strip_ns(elem.tag) != "drug":
            continue

        primary_id = get_primary_drugbank_id(elem, nstag)
        if not primary_id:
            cleanup_lxml(elem)
            continue

        drug_pk = drug_pk_map.get(primary_id)
        if drug_pk is None:
            missing_drug += 1
            processed_drugs += 1
            pbar.update(1)
            cleanup_lxml(elem)
            continue

        try:
            # -------------------------
            # A) drug_reference (general-references)
            # -------------------------
            genrefs = elem.find(nstag("general-references"))
            if genrefs is not None:
                refs = parse_reference_list(genrefs, nstag)
                for ref_tuple in refs:
                    ref_attempted += 1
                    before = len(ref_cache)
                    reference_pk = get_or_create_reference_pk(cur, ref_cache, ref_tuple)
                    if len(ref_cache) > before:
                        # might be found or created; we can't know here without extra queries.
                        # We'll estimate created by "insert occurred" using lastrowid heuristic:
                        # Better: check whether SELECT hit; we didn’t store that. Keep ref_created approximate:
                        pass

                    # join
                    cur.execute(SQL_DRUG_REF_INSERT, (drug_pk, reference_pk))
                    drugref_attempted += 1
                    drugref_inserted += max(cur.rowcount, 0)

            # -------------------------
            # B) interactant_reference (if enabled)
            # -------------------------
            if DO_INTERACTANT_REFERENCE:
                # targets/enzymes/carriers/transporters have same reference-list structure under each item
                interactant_specs = [
                    ("target", "targets", "target"),
                    ("enzyme", "enzymes", "enzyme"),
                    ("carrier", "carriers", "carrier"),
                    ("transporter", "transporters", "transporter"),
                ]

                for kind, container_tag, item_tag in interactant_specs:
                    container = elem.find(nstag(container_tag))
                    if container is None:
                        continue

                    for item in container.findall(nstag(item_tag)):
                        # match interactant_pk (must exist already from Stage 7 loader)
                        pos = item.get("position")
                        try:
                            pos_i = int(pos) if pos is not None else None
                        except:
                            pos_i = None

                        interactant_id = text(item.find(nstag("id")))
                        iname = text(item.find(nstag("name")))
                        organism = text(item.find(nstag("organism")))
                        known_action = text(item.find(nstag("known-action")))

                        cur.execute(SQL_INTERACTANT_FIND, (
                            drug_pk, kind, pos_i, interactant_id, iname, organism, known_action
                        ))
                        row = cur.fetchone()
                        if not row:
                            # interactant not loaded yet or mismatch
                            missing_interactant += 1
                            continue
                        interactant_pk = int(row[0])

                        refs_node = item.find(nstag("references"))
                        if refs_node is None:
                            continue

                        refs = parse_reference_list(refs_node, nstag)
                        for ref_tuple in refs:
                            ref_attempted += 1
                            reference_pk = get_or_create_reference_pk(cur, ref_cache, ref_tuple)

                            cur.execute(SQL_INTERACTANT_REF_INSERT, (interactant_pk, reference_pk))
                            intref_attempted += 1
                            intref_inserted += max(cur.rowcount, 0)

            processed_drugs += 1
            pbar.update(1)

            if processed_drugs % COMMIT_EVERY == 0:
                conn.commit()

            if SHOW_EVERY and processed_drugs % SHOW_EVERY == 0:
                elapsed = time.time() - t0
                rate = processed_drugs / elapsed if elapsed > 0 else 0.0
                print(f"[Stage4] drugs={processed_drugs:,}/{total_drugs_xml:,} "
                      f"drug_ref_ins={drugref_inserted:,} int_ref_ins={intref_inserted:,} "
                      f"missing_int={missing_interactant:,} rate={rate:.2f}/sec")

        except Exception as e:
            errors += 1
            conn.rollback()
            print(f"[Stage4] ERROR drug_pk={drug_pk} ({primary_id}): {e}")

        finally:
            cleanup_lxml(elem)

    conn.commit()

finally:
    try:
        pbar.close()
    except Exception:
        pass
    cur.close()
    conn.close()

seconds = time.time() - t0

# Post-run DB counts (optional but helpful)
conn_post = get_conn()
ref_rows = count_table(conn_post, "reference_item") if table_exists(conn_post, "reference_item") else -1
drugref_rows = count_table(conn_post, "drug_reference") if table_exists(conn_post, "drug_reference") else -1
intref_rows = count_table(conn_post, "interactant_reference") if (DO_INTERACTANT_REFERENCE and table_exists(conn_post, "interactant_reference")) else -1
conn_post.close()

# Final report table
headers = ["Metric", "Value"]
rows = [
    ["Total drugs in XML", f"{total_drugs_xml:,}"],
    ["Processed drugs", f"{processed_drugs:,}"],
    ["Missing drug (not in DB)", f"{missing_drug:,}"],
    ["Reference tuples seen (attempted)", f"{ref_attempted:,}"],
    ["Drug_reference attempted", f"{drugref_attempted:,}"],
    ["Drug_reference inserted", f"{drugref_inserted:,}"],
    ["Interactant_reference attempted", f"{intref_attempted:,}"],
    ["Interactant_reference inserted", f"{intref_inserted:,}"],
    ["Missing interactant matches", f"{missing_interactant:,}"],
    ["Errors", f"{errors:,}"],
    ["Elapsed", f"{seconds:.1f}s"],
    ["DB rows: reference_item", "N/A" if ref_rows < 0 else f"{ref_rows:,}"],
    ["DB rows: drug_reference", "N/A" if drugref_rows < 0 else f"{drugref_rows:,}"],
    ["DB rows: interactant_reference", "N/A" if intref_rows < 0 else f"{intref_rows:,}"],
]

# pretty print
w = [len(h) for h in headers]
for r in rows:
    for i, v in enumerate(r):
        w[i] = max(w[i], len(str(v)))
def sep(ch="-"):
    return "+".join(ch * (x + 2) for x in w)
def line(r):
    return "|".join(" " + str(r[i]).ljust(w[i]) + " " for i in range(2))

print(sep("-"))
print(line(headers))
print(sep("="))
for r in rows:
    print(line(r))
    print(sep("-"))

Stage 4: counting total <drug> in XML (pre-pass)...
Total drugs in XML: 73,687


Stage4: drug_reference + interactant_reference:   0%|                                      | 0/73687 [00:00<?,…

[Stage4] drugs=10,000/73,687 drug_ref_ins=733 int_ref_ins=0 missing_int=900 rate=3481.39/sec
[Stage4] drugs=20,000/73,687 drug_ref_ins=3,211 int_ref_ins=0 missing_int=4,045 rate=1383.41/sec
[Stage4] drugs=30,000/73,687 drug_ref_ins=8,965 int_ref_ins=0 missing_int=12,140 rate=606.25/sec
[Stage4] drugs=40,000/73,687 drug_ref_ins=9,043 int_ref_ins=0 missing_int=14,233 rate=751.93/sec
[Stage4] drugs=50,000/73,687 drug_ref_ins=9,136 int_ref_ins=0 missing_int=16,312 rate=894.31/sec
[Stage4] drugs=60,000/73,687 drug_ref_ins=10,549 int_ref_ins=0 missing_int=19,353 rate=961.49/sec
[Stage4] drugs=70,000/73,687 drug_ref_ins=23,691 int_ref_ins=0 missing_int=32,542 rate=487.65/sec
-----------------------------------+--------
 Metric                            | Value  
 Total drugs in XML                | 73,687 
-----------------------------------+--------
 Processed drugs                   | 73,687 
-----------------------------------+--------
 Missing drug (not in DB)          | 0      
--------

<small>

### Stage 5 — Pathway joins and pathway members

This stage loads **pathway-related tables**:

- **`pathway`** *(independent; PK = `smpdb_id`)*  
- **`drug_pathway`** *(depends on `drug` + `pathway`)*  
- **`pathway_drug_member`** *(depends on `pathway`)*  
- **`pathway_enzyme_member`** *(depends on `pathway`)*  

---

#### Prerequisites
Before running Stage 5, you must have:

- **Stage 1:** `drug` loaded (so `drug_pk` exists)

---

#### What the Stage 5 cell does

**1) Counts total drugs first (pre-pass)**  
- Makes a streaming pass over the XML to count total `<drug>` elements.  
- Uses this total to show a **true progress bar**.

**2) Main insertion pass (with progress)**  
- Streams the XML again and processes each `<drug>`:
  - For each `<pathway>` inside `<pathways>`:
    - Upserts into **`pathway(smpdb_id, name, category)`**
    - Inserts the link into **`drug_pathway(drug_pk, smpdb_id)`**
    - Inserts pathway drug members into **`pathway_drug_member(smpdb_id, drugbank_id, name)`**
    - Inserts pathway enzymes into **`pathway_enzyme_member`** using UniProt IDs

---

#### Notes about `pathway_enzyme_member`
DrugBank XML provides pathway enzymes as a list of **`<uniprot-id>`** values.

The loader automatically detects whether your table uses:
- `uniprot_id` column (preferred), or
- `drugbank_id` column (fallback: stores the UniProt ID in that column)

---

#### Insert behavior
- `pathway`: **UPSERT** (`ON DUPLICATE KEY UPDATE`)
- `drug_pathway`: `INSERT IGNORE`
- `pathway_drug_member`: UPSERT (updates member `name`)
- `pathway_enzyme_member`: `INSERT IGNORE`

---

#### Output / reporting
The stage prints:
- Pre-pass drug count
- Progress bar during insertion
- Periodic status logs (counts and rate)
- Final summary table:
  - Attempted vs affected rows per table
  - Total rows in DB for each Stage 5 table

</small>


In [27]:
# =============================================================================
# STAGE 5 (Jupyter single cell)
# Inserts:
#   - pathway              (smpdb_id PK)  [insert/update]
#   - drug_pathway         (drug_pk + smpdb_id) [insert ignore]
#   - pathway_drug_member  (smpdb_id + drugbank_id) [upsert name]
#   - pathway_enzyme_member (depends on your schema: smpdb_id + uniprot_id OR drugbank_id) [insert ignore]
#
# As requested:
#   1) FIRST: streaming pre-pass counts total <drug> in XML
#   2) THEN: main pass with a progress bar based on that total
# =============================================================================

import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, List, Tuple

from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# ---------------------------
# CONFIG (EDIT THESE)
# ---------------------------
XML_PATH = r"db/drugbank.xml"   # <-- change

MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "root"
MYSQL_PASSWORD = ""
MYSQL_DB = "drugbank"

COMMIT_EVERY = 1000
SHOW_EVERY = 10000
# ---------------------------

def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
        password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=False
    )

def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def detect_namespace(xml_path: str) -> Optional[str]:
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns: Optional[str]):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem) -> Optional[str]:
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def bool_from_text(s: Optional[str]) -> Optional[bool]:
    if s is None:
        return None
    v = s.strip().lower()
    if v in ("true", "1", "yes"): return True
    if v in ("false", "0", "no"): return False
    return None

def get_primary_drugbank_id(drug_elem, nstag) -> Optional[str]:
    ids = drug_elem.findall(nstag("drugbank-id"))
    fallback = None
    primary = None
    for ide in ids:
        did = text(ide)
        if not did:
            continue
        if fallback is None:
            fallback = did
        if bool_from_text(ide.get("primary")):
            primary = did
            break
    return primary or fallback

def count_total_drugs_in_xml(xml_path: str) -> int:
    total = 0
    ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
    for _, elem in ctx:
        if strip_ns(elem.tag) == "drug":
            total += 1
            cleanup_lxml(elem)
    return total

def load_drug_pk_map(conn) -> Dict[str, int]:
    cur = conn.cursor()
    cur.execute("SELECT primary_drugbank_id, drug_pk FROM drug")
    mp = {row[0]: int(row[1]) for row in cur.fetchall()}
    cur.close()
    return mp

def get_table_columns(conn, table: str) -> List[str]:
    cur = conn.cursor()
    cur.execute(
        "SELECT column_name FROM information_schema.columns "
        "WHERE table_schema=%s AND table_name=%s",
        (MYSQL_DB, table),
    )
    cols = [r[0] for r in cur.fetchall()]
    cur.close()
    return cols

def print_table(rows: List[List[Any]], headers: List[str]) -> None:
    cols = len(headers)
    widths = [len(h) for h in headers]
    for r in rows:
        for i in range(cols):
            widths[i] = max(widths[i], len(str(r[i])))

    def sep(ch="-"):
        return "+".join(ch * (w + 2) for w in widths)

    def render(r):
        return "|".join(" " + str(r[i]).ljust(widths[i]) + " " for i in range(cols))

    print(sep("-"))
    print(render(headers))
    print(sep("="))
    for r in rows:
        print(render(r))
        print(sep("-"))

@dataclass
class Stage5Result:
    processed_drugs: int
    missing_drug_in_db: int
    attempted_pathway: int
    affected_pathway: int
    attempted_drug_pathway: int
    affected_drug_pathway: int
    attempted_pathway_drug_member: int
    affected_pathway_drug_member: int
    attempted_pathway_enzyme_member: int
    affected_pathway_enzyme_member: int
    seconds: float

# ---------------------------
# PRE-PASS: Count total drugs
# ---------------------------
print("Stage 5 pre-pass: counting total <drug> in XML...")
total_drugs_xml = count_total_drugs_in_xml(XML_PATH)
print(f"Total drugs in XML: {total_drugs_xml:,}")

# ---------------------------
# Prepare DB + schema detection
# ---------------------------
conn = get_conn()
cur = conn.cursor()

drug_pk_map = load_drug_pk_map(conn)
if not drug_pk_map:
    cur.close()
    conn.close()
    raise RuntimeError("drug table is empty. Run Stage 1 (drug) first.")

# Detect how pathway_enzyme_member is defined
pem_cols = get_table_columns(conn, "pathway_enzyme_member")
pem_has_uniprot = "uniprot_id" in pem_cols
pem_has_drugbank = "drugbank_id" in pem_cols
pem_has_name = "name" in pem_cols  # some schemas include this

# SQL statements
SQL_PATHWAY_UPSERT = """
INSERT INTO pathway (smpdb_id, name, category)
VALUES (%s, %s, %s)
ON DUPLICATE KEY UPDATE
  name=VALUES(name),
  category=VALUES(category)
"""

SQL_DRUG_PATHWAY = "INSERT IGNORE INTO drug_pathway (drug_pk, smpdb_id) VALUES (%s, %s)"

SQL_PATHWAY_DRUG_MEMBER = """
INSERT INTO pathway_drug_member (smpdb_id, drugbank_id, name)
VALUES (%s, %s, %s)
ON DUPLICATE KEY UPDATE name=VALUES(name)
"""

if pem_has_uniprot:
    SQL_PATHWAY_ENZ_MEMBER = "INSERT IGNORE INTO pathway_enzyme_member (smpdb_id, uniprot_id) VALUES (%s, %s)"
elif pem_has_drugbank:
    # XML provides uniprot-id; if your table uses drugbank_id, we store uniprot-id there (still preserves data)
    SQL_PATHWAY_ENZ_MEMBER = "INSERT IGNORE INTO pathway_enzyme_member (smpdb_id, drugbank_id) VALUES (%s, %s)"
else:
    cur.close()
    conn.close()
    raise RuntimeError("pathway_enzyme_member table has neither uniprot_id nor drugbank_id column.")

# ---------------------------
# MAIN PASS: Insert Stage 5
# ---------------------------
attempted_pathway = affected_pathway = 0
attempted_drug_pathway = affected_drug_pathway = 0
attempted_pdm = affected_pdm = 0
attempted_pem = affected_pem = 0

processed_drugs = 0
missing_drug_in_db = 0

ns = detect_namespace(XML_PATH)
nstag = nstag_factory(ns)

t0 = time.time()
pbar = tqdm(total=total_drugs_xml, unit="drug", dynamic_ncols=True, desc="Stage 5: pathways")

try:
    ctx = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
    for _, elem in ctx:
        if strip_ns(elem.tag) != "drug":
            continue

        primary_id = get_primary_drugbank_id(elem, nstag)
        if not primary_id:
            processed_drugs += 1
            pbar.update(1)
            cleanup_lxml(elem)
            continue

        drug_pk = drug_pk_map.get(primary_id)
        if drug_pk is None:
            missing_drug_in_db += 1
            processed_drugs += 1
            pbar.update(1)
            cleanup_lxml(elem)
            continue

        try:
            pathways = elem.find(nstag("pathways"))
            if pathways is not None:
                for pw in pathways.findall(nstag("pathway")):
                    smpdb_id = text(pw.find(nstag("smpdb-id")))
                    if not smpdb_id:
                        continue

                    pw_name = text(pw.find(nstag("name")))
                    pw_cat = text(pw.find(nstag("category")))

                    # pathway upsert
                    cur.execute(SQL_PATHWAY_UPSERT, (smpdb_id, pw_name, pw_cat))
                    attempted_pathway += 1
                    affected_pathway += max(cur.rowcount, 0)

                    # drug_pathway join
                    cur.execute(SQL_DRUG_PATHWAY, (drug_pk, smpdb_id))
                    attempted_drug_pathway += 1
                    affected_drug_pathway += max(cur.rowcount, 0)

                    # pathway_drug_member (members inside pathway)
                    pw_drugs = pw.find(nstag("drugs"))
                    if pw_drugs is not None:
                        for pd in pw_drugs.findall(nstag("drug")):
                            did = text(pd.find(nstag("drugbank-id")))
                            if not did:
                                continue
                            dnm = text(pd.find(nstag("name")))
                            cur.execute(SQL_PATHWAY_DRUG_MEMBER, (smpdb_id, did, dnm))
                            attempted_pdm += 1
                            affected_pdm += max(cur.rowcount, 0)

                    # pathway_enzyme_member (uniprot-id list)
                    pw_enz = pw.find(nstag("enzymes"))
                    if pw_enz is not None:
                        for uid in pw_enz.findall(nstag("uniprot-id")):
                            u = text(uid)
                            if not u:
                                continue
                            cur.execute(SQL_PATHWAY_ENZ_MEMBER, (smpdb_id, u))
                            attempted_pem += 1
                            affected_pem += max(cur.rowcount, 0)

            processed_drugs += 1
            pbar.update(1)

            if processed_drugs % COMMIT_EVERY == 0:
                conn.commit()

            if SHOW_EVERY and processed_drugs % SHOW_EVERY == 0:
                elapsed = time.time() - t0
                rate = processed_drugs / elapsed if elapsed > 0 else 0.0
                print(f"[Stage5] drugs={processed_drugs:,}/{total_drugs_xml:,} "
                      f"pathway={attempted_pathway:,} drug_pathway={attempted_drug_pathway:,} "
                      f"pdm={attempted_pdm:,} pem={attempted_pem:,} rate={rate:.2f}/sec")

        except Exception as e:
            conn.rollback()
            print(f"[Stage5] ERROR for drug_pk={drug_pk} ({primary_id}): {e}")

        finally:
            cleanup_lxml(elem)

    conn.commit()

finally:
    try:
        pbar.close()
    except Exception:
        pass
    cur.close()
    conn.close()

seconds = time.time() - t0

# ---------------------------
# Final DB row counts
# ---------------------------
conn2 = get_conn()
rows_pathway = count_table(conn2, "pathway")
rows_drug_pathway = count_table(conn2, "drug_pathway")
rows_pdm = count_table(conn2, "pathway_drug_member")
rows_pem = count_table(conn2, "pathway_enzyme_member")
conn2.close()

print("\n=== Stage 5 Summary ===")
print_table(
    [
        ["pathway", f"{attempted_pathway:,}", f"{affected_pathway:,}", f"{rows_pathway:,}"],
        ["drug_pathway", f"{attempted_drug_pathway:,}", f"{affected_drug_pathway:,}", f"{rows_drug_pathway:,}"],
        ["pathway_drug_member", f"{attempted_pdm:,}", f"{affected_pdm:,}", f"{rows_pdm:,}"],
        ["pathway_enzyme_member", f"{attempted_pem:,}", f"{affected_pem:,}", f"{rows_pem:,}"],
    ],
    headers=["Table", "Attempted", "Affected(rowcount)", "Rows in DB"],
)

print("\nOther stats:")
print(f"- Drugs in XML (pre-pass): {total_drugs_xml:,}")
print(f"- Drugs processed:         {processed_drugs:,}")
print(f"- Missing drugs in DB:     {missing_drug_in_db:,}")
print(f"- pathway_enzyme_member uses column: {'uniprot_id' if pem_has_uniprot else 'drugbank_id'}")
print(f"- Elapsed: {seconds:.1f}s")


Stage 5 pre-pass: counting total <drug> in XML...
Total drugs in XML: 73,687


Stage 5: pathways:   0%|                                                                   | 0/73687 [00:00<?,…

[Stage5] drugs=10,000/73,687 pathway=599 drug_pathway=599 pdm=0 pem=9,851 rate=3123.78/sec
[Stage5] drugs=20,000/73,687 pathway=1,243 drug_pathway=1,243 pdm=0 pem=20,867 rate=1566.68/sec
[Stage5] drugs=30,000/73,687 pathway=2,012 drug_pathway=2,012 pdm=0 pem=37,609 rate=888.87/sec
[Stage5] drugs=40,000/73,687 pathway=2,558 drug_pathway=2,558 pdm=0 pem=48,385 rate=1071.54/sec
[Stage5] drugs=50,000/73,687 pathway=3,104 drug_pathway=3,104 pdm=0 pem=59,404 rate=1229.08/sec
[Stage5] drugs=60,000/73,687 pathway=3,648 drug_pathway=3,648 pdm=0 pem=69,879 rate=1320.79/sec
[Stage5] drugs=70,000/73,687 pathway=3,780 drug_pathway=3,780 pdm=0 pem=71,969 rate=920.10/sec

=== Stage 5 Summary ===
-----------------------+-----------+--------------------+------------
 Table                 | Attempted | Affected(rowcount) | Rows in DB 
 pathway               | 3,780     | 877                | 877        
-----------------------+-----------+--------------------+------------
 drug_pathway          | 3,780

<small>

### Stage 6 — Reaction children (`reaction_element`, `reaction_enzyme`)

**Goal:** Populate the reaction child tables after `reaction` exists.

**Tables inserted**
- `reaction_element` *(FK → `reaction.reaction_pk`)*
- `reaction_enzyme` *(FK → `reaction.reaction_pk`)*

---

#### Prerequisites
Before running Stage 6:
- **Stage 1** must be loaded: `drug(primary_drugbank_id → drug_pk)`
- **Stage 2** must be loaded: `reaction` rows exist (so `reaction_pk` exists)

---

#### How the Stage 6 code works

**1) Pre-pass (with progress bar)**
- Streams the XML once to count total `<drug>` elements
- Uses this total for an accurate progress bar in the main pass

**2) Main pass (with progress bar)**
- Streams the XML again and processes each `<drug>`
- For each drug:
  - Loads DB reactions for that drug:
    - `SELECT reaction_pk, sequence_text FROM reaction WHERE drug_pk=...`
  - Reads XML `<reactions><reaction>...</reaction></reactions>`
  - Maps each XML reaction to a DB `reaction_pk`:
    - tries matching by `sequence_text` first
    - falls back to positional mapping if needed
  - Inserts:
    - left/right sides into `reaction_element`
    - enzymes into `reaction_enzyme`

---

#### Notes
- Uses batching + commits for performance.
- Reports progress, insert counts, fallback usage, and final DB row counts.

</small>


In [28]:
# =============================================================================
# STAGE 6 (Jupyter single cell)
# Inserts:
#   33) reaction_element  (FK -> reaction.reaction_pk)
#   34) reaction_enzyme   (FK -> reaction.reaction_pk)
#
# As requested:
#   1) PRE-PASS: counts total <drug> in XML WITH a progress bar
#   2) MAIN PASS: processes XML with progress bar based on that total
#
# IMPORTANT prerequisites:
#   - Stage 1 loaded: drug(primary_drugbank_id -> drug_pk)
#   - Stage 2 loaded: reaction rows exist in DB for each drug (reaction_pk)
#
# Mapping XML reactions -> DB reaction_pk:
#   - For each drug, we fetch DB reactions: SELECT reaction_pk, sequence_text WHERE drug_pk=...
#   - We match reactions in order using sequence_text when possible, else by position.
# =============================================================================

import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, List, Tuple

from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# ---------------------------
# CONFIG (EDIT THESE)
# ---------------------------
XML_PATH = r"db/drugbank.xml"   # <-- change

MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "root"
MYSQL_PASSWORD = ""
MYSQL_DB = "drugbank"

COMMIT_EVERY = 1000
SHOW_EVERY = 10000
# ---------------------------

def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
        password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=False
    )

def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def detect_namespace(xml_path: str) -> Optional[str]:
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns: Optional[str]):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem) -> Optional[str]:
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def bool_from_text(s: Optional[str]) -> Optional[bool]:
    if s is None:
        return None
    v = s.strip().lower()
    if v in ("true", "1", "yes"): return True
    if v in ("false", "0", "no"): return False
    return None

def get_primary_drugbank_id(drug_elem, nstag) -> Optional[str]:
    ids = drug_elem.findall(nstag("drugbank-id"))
    fallback = None
    primary = None
    for ide in ids:
        did = text(ide)
        if not did:
            continue
        if fallback is None:
            fallback = did
        if bool_from_text(ide.get("primary")):
            primary = did
            break
    return primary or fallback

def load_drug_pk_map(conn) -> Dict[str, int]:
    cur = conn.cursor()
    cur.execute("SELECT primary_drugbank_id, drug_pk FROM drug")
    mp = {row[0]: int(row[1]) for row in cur.fetchall()}
    cur.close()
    return mp

def count_table(conn, table: str) -> int:
    cur = conn.cursor()
    cur.execute(f"SELECT COUNT(*) FROM `{table}`")
    n = int(cur.fetchone()[0])
    cur.close()
    return n

def print_table(rows: List[List[Any]], headers: List[str]) -> None:
    cols = len(headers)
    widths = [len(h) for h in headers]
    for r in rows:
        for i in range(cols):
            widths[i] = max(widths[i], len(str(r[i])))

    def sep(ch="-"):
        return "+".join(ch * (w + 2) for w in widths)

    def render(r):
        return "|".join(" " + str(r[i]).ljust(widths[i]) + " " for i in range(cols))

    print(sep("-"))
    print(render(headers))
    print(sep("="))
    for r in rows:
        print(render(r))
        print(sep("-"))

def count_total_drugs_with_progress(xml_path: str) -> int:
    total = 0
    p = tqdm(total=None, unit="drug", dynamic_ncols=True, desc="Pre-pass: counting <drug>")
    try:
        ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, elem in ctx:
            if strip_ns(elem.tag) == "drug":
                total += 1
                p.update(1)
                cleanup_lxml(elem)
    finally:
        try:
            p.close()
        except Exception:
            pass
    return total

# ---------------------------
# SQL for Stage 6
# ---------------------------
SQL_DB_REACTIONS_FOR_DRUG = """
SELECT reaction_pk, sequence_text
FROM reaction
WHERE drug_pk=%s
ORDER BY reaction_pk
"""

SQL_REACTION_ELEMENT_UPSERT = """
INSERT INTO reaction_element (reaction_pk, side, drugbank_id, name)
VALUES (%s,%s,%s,%s)
ON DUPLICATE KEY UPDATE
  drugbank_id=VALUES(drugbank_id),
  name=VALUES(name)
"""

SQL_REACTION_ENZYME_INSERT = """
INSERT IGNORE INTO reaction_enzyme (reaction_pk, drugbank_id, name, uniprot_id)
VALUES (%s,%s,%s,%s)
"""

def norm_seq(s: Optional[str]) -> str:
    return (s or "").strip()

def safe_id(s: Optional[str]) -> str:
    # reaction_enzyme PK columns may be NOT NULL in MySQL if they are in a PRIMARY KEY
    return (s or "").strip()

# ---------------------------
# PRE-PASS: Count total drugs in XML with progress
# ---------------------------
print("Stage 6 pre-pass starting...")
total_drugs_xml = count_total_drugs_with_progress(XML_PATH)
print(f"Total drugs in XML: {total_drugs_xml:,}")

# ---------------------------
# MAIN PASS: Insert reaction_element + reaction_enzyme
# ---------------------------
ns = detect_namespace(XML_PATH)
nstag = nstag_factory(ns)

conn = get_conn()
cur = conn.cursor()

drug_pk_map = load_drug_pk_map(conn)
if not drug_pk_map:
    cur.close()
    conn.close()
    raise RuntimeError("drug table is empty. Run Stage 1 first.")

attempted_el = inserted_el = 0
attempted_en = inserted_en = 0
processed_drugs = 0
missing_drug = 0
missing_db_reaction = 0
unmatched_by_seq_used_positional = 0
errors = 0

t0 = time.time()
pbar = tqdm(total=total_drugs_xml, unit="drug", dynamic_ncols=True, desc="Stage 6: reactions children")

try:
    ctx = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
    for _, drug_elem in ctx:
        if strip_ns(drug_elem.tag) != "drug":
            continue

        primary_id = get_primary_drugbank_id(drug_elem, nstag)
        drug_pk = drug_pk_map.get(primary_id) if primary_id else None
        if drug_pk is None:
            missing_drug += 1
            processed_drugs += 1
            pbar.update(1)
            cleanup_lxml(drug_elem)
            continue

        try:
            # Load DB reactions for this drug
            cur.execute(SQL_DB_REACTIONS_FOR_DRUG, (drug_pk,))
            db_reactions = [(int(rpk), norm_seq(seq)) for (rpk, seq) in cur.fetchall()]

            # Parse XML reactions
            reactions = drug_elem.find(nstag("reactions"))
            if reactions is not None:
                pointer = 0  # scanning pointer into db_reactions

                for rxn in reactions.findall(nstag("reaction")):
                    xml_seq = norm_seq(text(rxn.find(nstag("sequence"))))

                    # Try to find a DB reaction with same sequence from pointer onward
                    match_idx = None
                    for j in range(pointer, len(db_reactions)):
                        if db_reactions[j][1] == xml_seq:
                            match_idx = j
                            break

                    # Fallback: positional mapping
                    if match_idx is None:
                        if pointer < len(db_reactions):
                            match_idx = pointer
                            unmatched_by_seq_used_positional += 1
                        else:
                            missing_db_reaction += 1
                            continue

                    reaction_pk = db_reactions[match_idx][0]
                    pointer = match_idx + 1

                    # reaction_element: left-element + right-element
                    left = rxn.find(nstag("left-element"))
                    if left is not None:
                        lid = text(left.find(nstag("drugbank-id")))
                        lname = text(left.find(nstag("name")))
                        cur.execute(SQL_REACTION_ELEMENT_UPSERT, (reaction_pk, "left", lid, lname))
                        attempted_el += 1
                        inserted_el += max(cur.rowcount, 0)

                    right = rxn.find(nstag("right-element"))
                    if right is not None:
                        rid = text(right.find(nstag("drugbank-id")))
                        rname = text(right.find(nstag("name")))
                        cur.execute(SQL_REACTION_ELEMENT_UPSERT, (reaction_pk, "right", rid, rname))
                        attempted_el += 1
                        inserted_el += max(cur.rowcount, 0)

                    # reaction_enzyme: list of enzymes inside reaction
                    enzs = rxn.find(nstag("enzymes"))
                    if enzs is not None:
                        for enz in enzs.findall(nstag("enzyme")):
                            eid = safe_id(text(enz.find(nstag("drugbank-id"))))
                            uname = text(enz.find(nstag("name")))
                            uniprot = safe_id(text(enz.find(nstag("uniprot-id"))))

                            # Avoid inserting totally empty identifiers (would collide)
                            if not eid and not uniprot:
                                continue

                            cur.execute(SQL_REACTION_ENZYME_INSERT, (reaction_pk, eid, uname, uniprot))
                            attempted_en += 1
                            inserted_en += max(cur.rowcount, 0)

            processed_drugs += 1
            pbar.update(1)

            if processed_drugs % COMMIT_EVERY == 0:
                conn.commit()

            if SHOW_EVERY and processed_drugs % SHOW_EVERY == 0:
                elapsed = time.time() - t0
                rate = processed_drugs / elapsed if elapsed > 0 else 0.0
                print(f"[Stage6] drugs={processed_drugs:,}/{total_drugs_xml:,} "
                      f"elem_ins={inserted_el:,} enz_ins={inserted_en:,} "
                      f"missing_db_rxn={missing_db_reaction:,} rate={rate:.2f}/sec")

        except Exception as e:
            errors += 1
            conn.rollback()
            print(f"[Stage6] ERROR drug_pk={drug_pk} ({primary_id}): {e}")

        finally:
            cleanup_lxml(drug_elem)

    conn.commit()

finally:
    try:
        pbar.close()
    except Exception:
        pass
    cur.close()
    conn.close()

elapsed = time.time() - t0

# Final DB counts
conn2 = get_conn()
rows_el = count_table(conn2, "reaction_element")
rows_en = count_table(conn2, "reaction_enzyme")
conn2.close()

print("\n=== Stage 6 Summary ===")
print_table(
    [
        ["reaction_element", f"{attempted_el:,}", f"{inserted_el:,}", f"{rows_el:,}"],
        ["reaction_enzyme", f"{attempted_en:,}", f"{inserted_en:,}", f"{rows_en:,}"],
    ],
    headers=["Table", "Attempted", "Affected(rowcount)", "Rows in DB"],
)

print("\nOther stats:")
print(f"- Drugs in XML (pre-pass):            {total_drugs_xml:,}")
print(f"- Drugs processed:                    {processed_drugs:,}")
print(f"- Missing drugs in DB (Stage 1 miss): {missing_drug:,}")
print(f"- Missing DB reaction match:          {missing_db_reaction:,}")
print(f"- Used positional reaction fallback:  {unmatched_by_seq_used_positional:,}")
print(f"- Errors:                             {errors:,}")
print(f"- Elapsed:                            {elapsed:.1f}s")


Stage 6 pre-pass starting...


Pre-pass: counting <drug>: 0drug [00:00, ?drug/s]

Total drugs in XML: 73,687


Stage 6: reactions children:   0%|                                                         | 0/73687 [00:00<?,…

[Stage6] drugs=10,000/73,687 elem_ins=64 enz_ins=9 missing_db_rxn=0 rate=2226.59/sec
[Stage6] drugs=20,000/73,687 elem_ins=1,388 enz_ins=676 missing_db_rxn=0 rate=1279.89/sec
[Stage6] drugs=30,000/73,687 elem_ins=5,006 enz_ins=2,463 missing_db_rxn=0 rate=770.06/sec
[Stage6] drugs=40,000/73,687 elem_ins=5,042 enz_ins=2,477 missing_db_rxn=0 rate=918.63/sec
[Stage6] drugs=50,000/73,687 elem_ins=5,076 enz_ins=2,496 missing_db_rxn=0 rate=1041.53/sec
[Stage6] drugs=60,000/73,687 elem_ins=5,296 enz_ins=2,660 missing_db_rxn=0 rate=1110.59/sec
[Stage6] drugs=70,000/73,687 elem_ins=7,982 enz_ins=3,701 missing_db_rxn=0 rate=765.56/sec

=== Stage 6 Summary ===
------------------+-----------+--------------------+------------
 Table            | Attempted | Affected(rowcount) | Rows in DB 
 reaction_element | 8,092     | 8,092              | 8,092      
------------------+-----------+--------------------+------------
 reaction_enzyme  | 3,737     | 3,737              | 3,737      
------------------

<small>

### Stage 7 — Interactant group (targets / enzymes / carriers / transporters)

This stage loads the **interactant dependency chain**:

- **`interactant`** *(FK → `drug`)*  
- **`interactant_action`** *(FK → `interactant`)*  
- **`polypeptide`** *(needed so `interactant_polypeptide` FK is valid)*  
- **`interactant_polypeptide`** *(FK → `interactant` + `polypeptide`)*  
- **`reference_item`** *(inserted as needed)*  
- **`interactant_reference`** *(FK → `interactant` + `reference_item`)*  

---

#### Prerequisites
Before running Stage 7, you must have:

- **Stage 1:** `drug` loaded (so `drug_pk` exists)

---

#### What the Stage 7 cell does

**1) Pre-pass: count total drugs (with progress bar)**  
- Streams the XML once to count `<drug>` elements  
- Uses this number for the main progress bar

**2) Main pass: insert interactants (with progress bar)**  
- Streams the XML again and processes each `<drug>`
- For each drug, it processes 4 interactant groups:
  - `targets/target`
  - `enzymes/enzyme`
  - `carriers/carrier`
  - `transporters/transporter`

For every interactant item it:
- **Creates or finds** a matching row in `interactant` using:
  - `drug_pk + kind + position + id + name + organism + known_action`
- Inserts actions into `interactant_action`
- Inserts polypeptides into `polypeptide` (UPSERT on `(polypeptide_id, source)`)
- Inserts links into `interactant_polypeptide`
- Inserts references:
  - creates/gets `reference_item` (deduped by composite key)
  - inserts into `interactant_reference`

---

#### Insert behavior / rerun safety
- `interactant`: **get-or-create** (find first, otherwise insert)
- `interactant_action`: `INSERT IGNORE`
- `polypeptide`: UPSERT (`ON DUPLICATE KEY UPDATE`)
- `interactant_polypeptide`: `INSERT IGNORE`
- `reference_item`: deduped using cache + DB lookup
- `interactant_reference`: `INSERT IGNORE`

---

#### Output / reporting
The stage prints:
- pre-pass drug count
- progress bar during processing
- periodic status logs (insert counts + speed)
- final summary including:
  - DB row counts for `interactant`, `interactant_action`, `polypeptide`,
    `interactant_polypeptide`, `reference_item`, and `interactant_reference`
  - missing drugs (if any)
  - runtime and error count

</small>


In [29]:
# =============================================================================
# STAGE 7 (Jupyter single cell)
# Inserts:
#   35) interactant              (FK -> drug.drug_pk)
#   36) interactant_action       (FK -> interactant.interactant_pk)
#   37) polypeptide              (PK -> (polypeptide_id, source))   [needed for FK]
#       interactant_polypeptide  (FK -> interactant + polypeptide)
#   38) reference_item (as-needed; Stage 0 prerequisite for refs) + interactant_reference
#
# As requested (like Stage 6):
#   1) PRE-PASS: counts total <drug> in XML WITH progress bar
#   2) MAIN PASS: processes XML with progress bar based on that total
#
# Idempotency (rerun-safe):
# - interactant has no unique key, so we "get-or-create" by:
#   (drug_pk, kind, position, interactant_id, name, organism, known_action)
# - actions/link tables use INSERT IGNORE (PKs prevent duplicates)
# - polypeptide uses UPSERT on PK (polypeptide_id, source)
# - reference_item is deduped by composite key lookup + cache
# =============================================================================

import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple, List

from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# ---------------------------
# CONFIG (EDIT THESE)
# ---------------------------
XML_PATH = r"db/drugbank.xml"   # <-- change

MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "root"
MYSQL_PASSWORD = ""
MYSQL_DB = "drugbank"

COMMIT_EVERY = 500
SHOW_EVERY = 5000
# ---------------------------

def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
        password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=False
    )

def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def detect_namespace(xml_path: str) -> Optional[str]:
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns: Optional[str]):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem) -> Optional[str]:
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def bool_from_text(s: Optional[str]) -> Optional[bool]:
    if s is None:
        return None
    v = s.strip().lower()
    if v in ("true", "1", "yes"): return True
    if v in ("false", "0", "no"): return False
    return None

def get_primary_drugbank_id(drug_elem, nstag) -> Optional[str]:
    ids = drug_elem.findall(nstag("drugbank-id"))
    fallback = None
    primary = None
    for ide in ids:
        did = text(ide)
        if not did:
            continue
        if fallback is None:
            fallback = did
        if bool_from_text(ide.get("primary")):
            primary = did
            break
    return primary or fallback

def load_drug_pk_map(conn) -> Dict[str, int]:
    cur = conn.cursor()
    cur.execute("SELECT primary_drugbank_id, drug_pk FROM drug")
    mp = {row[0]: int(row[1]) for row in cur.fetchall()}
    cur.close()
    return mp

def count_table(conn, table: str) -> int:
    cur = conn.cursor()
    cur.execute(f"SELECT COUNT(*) FROM `{table}`")
    n = int(cur.fetchone()[0])
    cur.close()
    return n

def print_table(rows: List[List[Any]], headers: List[str]) -> None:
    cols = len(headers)
    widths = [len(h) for h in headers]
    for r in rows:
        for i in range(cols):
            widths[i] = max(widths[i], len(str(r[i])))

    def sep(ch="-"):
        return "+".join(ch * (w + 2) for w in widths)

    def render(r):
        return "|".join(" " + str(r[i]).ljust(widths[i]) + " " for i in range(cols))

    print(sep("-"))
    print(render(headers))
    print(sep("="))
    for r in rows:
        print(render(r))
        print(sep("-"))

def count_total_drugs_with_progress(xml_path: str) -> int:
    total = 0
    p = tqdm(total=None, unit="drug", dynamic_ncols=True, desc="Pre-pass: counting <drug>")
    try:
        ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, elem in ctx:
            if strip_ns(elem.tag) == "drug":
                total += 1
                p.update(1)
                cleanup_lxml(elem)
    finally:
        try:
            p.close()
        except Exception:
            pass
    return total

# ---------------------------
# Reference helpers (same idea as Stage 4)
# ---------------------------
SQL_REF_SELECT = """
SELECT reference_pk
FROM reference_item
WHERE ref_type <=> %s
  AND ref_id <=> %s
  AND pubmed_id <=> %s
  AND citation <=> %s
  AND isbn <=> %s
  AND title <=> %s
  AND url <=> %s
LIMIT 1
"""
SQL_REF_INSERT = """
INSERT INTO reference_item
  (ref_type, ref_id, pubmed_id, citation, isbn, title, url)
VALUES (%s,%s,%s,%s,%s,%s,%s)
"""
def parse_reference_list(refs_node, nstag):
    out = []
    arts = refs_node.find(nstag("articles"))
    if arts is not None:
        for a in arts.findall(nstag("article")):
            out.append(("article", text(a.find(nstag("ref-id"))), text(a.find(nstag("pubmed-id"))),
                        text(a.find(nstag("citation"))), None, None, None))
    tbs = refs_node.find(nstag("textbooks"))
    if tbs is not None:
        for t in tbs.findall(nstag("textbook")):
            out.append(("textbook", text(t.find(nstag("ref-id"))), None, text(t.find(nstag("citation"))),
                        text(t.find(nstag("isbn"))), None, None))
    links = refs_node.find(nstag("links"))
    if links is not None:
        for l in links.findall(nstag("link")):
            out.append(("link", text(l.find(nstag("ref-id"))), None, None, None,
                        text(l.find(nstag("title"))), text(l.find(nstag("url")))))
    atts = refs_node.find(nstag("attachments"))
    if atts is not None:
        for a in atts.findall(nstag("attachment")):
            out.append(("attachment", text(a.find(nstag("ref-id"))), None, None, None,
                        text(a.find(nstag("title"))), text(a.find(nstag("url")))))
    return out

def get_or_create_reference_pk(cur, cache: Dict[Tuple, int], ref_tuple: Tuple) -> int:
    if ref_tuple in cache:
        return cache[ref_tuple]
    cur.execute(SQL_REF_SELECT, ref_tuple)
    row = cur.fetchone()
    if row:
        pk = int(row[0])
        cache[ref_tuple] = pk
        return pk
    cur.execute(SQL_REF_INSERT, ref_tuple)
    pk = int(cur.lastrowid)
    cache[ref_tuple] = pk
    return pk

# ---------------------------
# Interactant + polypeptide SQL
# ---------------------------
SQL_INTERACTANT_FIND = """
SELECT interactant_pk
FROM interactant
WHERE drug_pk=%s
  AND kind <=> %s
  AND position <=> %s
  AND interactant_id <=> %s
  AND name <=> %s
  AND organism <=> %s
  AND known_action <=> %s
ORDER BY interactant_pk DESC
LIMIT 1
"""
SQL_INTERACTANT_INSERT = """
INSERT INTO interactant
  (drug_pk, kind, position, interactant_id, name, organism, known_action)
VALUES (%s,%s,%s,%s,%s,%s,%s)
"""

SQL_INTERACTANT_ACTION = "INSERT IGNORE INTO interactant_action (interactant_pk, action) VALUES (%s,%s)"
SQL_INTERACTANT_POLY = "INSERT IGNORE INTO interactant_polypeptide (interactant_pk, polypeptide_id, source) VALUES (%s,%s,%s)"
SQL_INTERACTANT_REF = "INSERT IGNORE INTO interactant_reference (interactant_pk, reference_pk) VALUES (%s,%s)"

SQL_POLYPEPTIDE_UPSERT = """
INSERT INTO polypeptide (
  polypeptide_id, source, name, general_function, specific_function, gene_name, locus,
  cellular_location, transmembrane_regions, signal_regions, theoretical_pi, molecular_weight,
  chromosome_location, organism_name, organism_ncbi_taxonomy_id,
  amino_acid_sequence, gene_sequence
)
VALUES (%s,%s,%s,%s,%s,%s,%s,
        %s,%s,%s,%s,%s,
        %s,%s,%s,
        %s,%s)
ON DUPLICATE KEY UPDATE
  name=VALUES(name),
  general_function=VALUES(general_function),
  specific_function=VALUES(specific_function),
  gene_name=VALUES(gene_name),
  locus=VALUES(locus),
  cellular_location=VALUES(cellular_location),
  transmembrane_regions=VALUES(transmembrane_regions),
  signal_regions=VALUES(signal_regions),
  theoretical_pi=VALUES(theoretical_pi),
  molecular_weight=VALUES(molecular_weight),
  chromosome_location=VALUES(chromosome_location),
  organism_name=VALUES(organism_name),
  organism_ncbi_taxonomy_id=VALUES(organism_ncbi_taxonomy_id),
  amino_acid_sequence=VALUES(amino_acid_sequence),
  gene_sequence=VALUES(gene_sequence)
"""

def get_or_create_interactant_pk(cur, drug_pk: int, kind: str, position: Optional[int],
                                interactant_id: Optional[str], name: Optional[str],
                                organism: Optional[str], known_action: Optional[str]) -> int:
    cur.execute(SQL_INTERACTANT_FIND, (drug_pk, kind, position, interactant_id, name, organism, known_action))
    row = cur.fetchone()
    if row:
        return int(row[0])
    cur.execute(SQL_INTERACTANT_INSERT, (drug_pk, kind, position, interactant_id, name, organism, known_action))
    return int(cur.lastrowid)

def safe_int(s: Optional[str]) -> Optional[int]:
    if s is None:
        return None
    try:
        return int(s)
    except:
        return None

# ---------------------------
# RUN STAGE 7
# ---------------------------
print("Stage 7 pre-pass starting...")
total_drugs_xml = count_total_drugs_with_progress(XML_PATH)
print(f"Total drugs in XML: {total_drugs_xml:,}")

ns = detect_namespace(XML_PATH)
nstag = nstag_factory(ns)

conn = get_conn()
cur = conn.cursor()

drug_pk_map = load_drug_pk_map(conn)
if not drug_pk_map:
    cur.close()
    conn.close()
    raise RuntimeError("drug table is empty. Run Stage 1 first.")

# Counters
attempt_i = ins_i = 0
attempt_a = ins_a = 0
attempt_poly = ins_poly = 0
attempt_link = ins_link = 0
attempt_ref = ins_ref = 0

processed_drugs = 0
missing_drug = 0
errors = 0

ref_cache: Dict[Tuple, int] = {}

t0 = time.time()
pbar = tqdm(total=total_drugs_xml, unit="drug", dynamic_ncols=True, desc="Stage 7: interactants")

try:
    ctx = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
    for _, drug_elem in ctx:
        if strip_ns(drug_elem.tag) != "drug":
            continue

        primary_id = get_primary_drugbank_id(drug_elem, nstag)
        drug_pk = drug_pk_map.get(primary_id) if primary_id else None
        if drug_pk is None:
            missing_drug += 1
            processed_drugs += 1
            pbar.update(1)
            cleanup_lxml(drug_elem)
            continue

        try:
            groups = [
                ("target", "targets", "target"),
                ("enzyme", "enzymes", "enzyme"),
                ("carrier", "carriers", "carrier"),
                ("transporter", "transporters", "transporter"),
            ]

            for kind, container_tag, item_tag in groups:
                container = drug_elem.find(nstag(container_tag))
                if container is None:
                    continue

                for item in container.findall(nstag(item_tag)):
                    pos_i = safe_int(item.get("position"))
                    interactant_id = text(item.find(nstag("id")))
                    name = text(item.find(nstag("name")))
                    organism = text(item.find(nstag("organism")))
                    known_action = text(item.find(nstag("known-action")))

                    # interactant get-or-create
                    interactant_pk = get_or_create_interactant_pk(cur, drug_pk, kind, pos_i, interactant_id, name, organism, known_action)
                    attempt_i += 1
                    ins_i += 1  # created or found (count as processed)

                    # actions
                    actions = item.find(nstag("actions"))
                    if actions is not None:
                        for act in actions.findall(nstag("action")):
                            v = text(act)
                            if not v:
                                continue
                            cur.execute(SQL_INTERACTANT_ACTION, (interactant_pk, v))
                            attempt_a += 1
                            ins_a += max(cur.rowcount, 0)

                    # references -> reference_item + interactant_reference
                    refs_node = item.find(nstag("references"))
                    if refs_node is not None:
                        for ref_tuple in parse_reference_list(refs_node, nstag):
                            reference_pk = get_or_create_reference_pk(cur, ref_cache, ref_tuple)
                            attempt_ref += 1
                            # ref table is upsert-ish; rowcount for SELECT hits is not tracked here
                            # so we count only link insert below
                            cur.execute(SQL_INTERACTANT_REF, (interactant_pk, reference_pk))
                            attempt_link += 1
                            ins_link += max(cur.rowcount, 0)

                    # polypeptides (ensure polypeptide exists, then link)
                    for poly in item.findall(nstag("polypeptide")):
                        poly_id = poly.get("id")
                        poly_source = poly.get("source")
                        if not poly_id or not poly_source:
                            continue

                        org_elem = poly.find(nstag("organism"))
                        org_name = text(org_elem)
                        org_tax = org_elem.get("ncbi-taxonomy-id") if org_elem is not None else None

                        aa_seq = text(poly.find(nstag("amino-acid-sequence")))
                        gene_seq = text(poly.find(nstag("gene-sequence")))

                        cur.execute(SQL_POLYPEPTIDE_UPSERT, (
                            poly_id, poly_source,
                            text(poly.find(nstag("name"))),
                            text(poly.find(nstag("general-function"))),
                            text(poly.find(nstag("specific-function"))),
                            text(poly.find(nstag("gene-name"))),
                            text(poly.find(nstag("locus"))),
                            text(poly.find(nstag("cellular-location"))),
                            text(poly.find(nstag("transmembrane-regions"))),
                            text(poly.find(nstag("signal-regions"))),
                            text(poly.find(nstag("theoretical-pi"))),
                            text(poly.find(nstag("molecular-weight"))),
                            text(poly.find(nstag("chromosome-location"))),
                            org_name, org_tax,
                            aa_seq, gene_seq
                        ))
                        attempt_poly += 1
                        ins_poly += max(cur.rowcount, 0)

                        cur.execute(SQL_INTERACTANT_POLY, (interactant_pk, poly_id, poly_source))
                        attempt_link += 1
                        ins_link += max(cur.rowcount, 0)

            processed_drugs += 1
            pbar.update(1)

            if processed_drugs % COMMIT_EVERY == 0:
                conn.commit()

            if SHOW_EVERY and processed_drugs % SHOW_EVERY == 0:
                elapsed = time.time() - t0
                rate = processed_drugs / elapsed if elapsed > 0 else 0.0
                print(f"[Stage7] drugs={processed_drugs:,}/{total_drugs_xml:,} "
                      f"interactant_ops={attempt_i:,} action_ins={ins_a:,} "
                      f"polypeptide_ops={attempt_poly:,} link_ins={ins_link:,} rate={rate:.2f}/sec")

        except Exception as e:
            errors += 1
            conn.rollback()
            print(f"[Stage7] ERROR drug_pk={drug_pk} ({primary_id}): {e}")

        finally:
            cleanup_lxml(drug_elem)

    conn.commit()

finally:
    try:
        pbar.close()
    except Exception:
        pass
    cur.close()
    conn.close()

elapsed = time.time() - t0

# DB counts
conn2 = get_conn()
rows_interactant = count_table(conn2, "interactant")
rows_action = count_table(conn2, "interactant_action")
rows_poly = count_table(conn2, "polypeptide")
rows_link_poly = count_table(conn2, "interactant_polypeptide")
rows_refitem = count_table(conn2, "reference_item")
rows_intref = count_table(conn2, "interactant_reference")
conn2.close()

print("\n=== Stage 7 Summary ===")
print_table(
    [
        ["interactant", f"{attempt_i:,}", f"{rows_interactant:,}"],
        ["interactant_action", f"{attempt_a:,}", f"{rows_action:,}"],
        ["polypeptide", f"{attempt_poly:,}", f"{rows_poly:,}"],
        ["interactant_polypeptide", f"{attempt_link:,}", f"{rows_link_poly:,}"],
        ["reference_item (total)", "-", f"{rows_refitem:,}"],
        ["interactant_reference", f"{attempt_link:,}", f"{rows_intref:,}"],
    ],
    headers=["Table", "Attempted", "Rows in DB"],
)

print("\nOther stats:")
print(f"- Drugs in XML (pre-pass):            {total_drugs_xml:,}")
print(f"- Drugs processed:                    {processed_drugs:,}")
print(f"- Missing drugs in DB:                {missing_drug:,}")
print(f"- Errors:                             {errors:,}")
print(f"- Elapsed:                            {elapsed:.1f}s")


Stage 7 pre-pass starting...


Pre-pass: counting <drug>: 0drug [00:00, ?drug/s]

Total drugs in XML: 73,687


Stage 7: interactants:   0%|                                                               | 0/73687 [00:00<?,…

[Stage7] drugs=5,000/73,687 interactant_ops=609 action_ins=539 polypeptide_ops=594 link_ins=2,374 rate=967.33/sec
[Stage7] drugs=10,000/73,687 interactant_ops=900 action_ins=681 polypeptide_ops=900 link_ins=3,393 rate=1539.94/sec
[Stage7] drugs=15,000/73,687 interactant_ops=1,289 action_ins=875 polypeptide_ops=1,314 link_ins=4,713 rate=1728.25/sec
[Stage7] drugs=20,000/73,687 interactant_ops=4,045 action_ins=3,709 polypeptide_ops=4,696 link_ins=14,954 rate=745.99/sec
[Stage7] drugs=25,000/73,687 interactant_ops=10,266 action_ins=10,143 polypeptide_ops=12,059 link_ins=37,557 rate=364.43/sec
[Stage7] drugs=30,000/73,687 interactant_ops=12,140 action_ins=11,714 polypeptide_ops=14,311 link_ins=43,829 rate=368.96/sec
[Stage7] drugs=35,000/73,687 interactant_ops=13,059 action_ins=11,951 polypeptide_ops=15,256 link_ins=46,068 rate=412.21/sec
[Stage7] drugs=40,000/73,687 interactant_ops=14,233 action_ins=12,286 polypeptide_ops=16,464 link_ins=48,911 rate=447.74/sec
[Stage7] drugs=45,000/73,687

<small>

### Stage 8 — Polypeptide children (`polypeptide_external_identifier`, `polypeptide_synonym`)

This stage loads the **child tables of `polypeptide`**:

- **`polypeptide_external_identifier`** *(FK → `polypeptide(polypeptide_id, source)`)*
- **`polypeptide_synonym`** *(FK → `polypeptide(polypeptide_id, source)`)*
  
---

#### Prerequisites
Before running Stage 8, you must have:

- **Stage 7:** `polypeptide` loaded (so `(polypeptide_id, source)` exists)

---

#### What the Stage 8 cell does

**1) Pre-pass: count total drugs (with progress bar)**  
- Streams the XML once to count `<drug>` elements  
- Uses this number for the main progress bar

**2) Main pass: insert polypeptide children (with progress bar)**  
- Streams the XML again and processes each `<drug>`
- Searches for polypeptides under the interactant groups:
  - `targets/target/polypeptide`
  - `enzymes/enzyme/polypeptide`
  - `carriers/carrier/polypeptide`
  - `transporters/transporter/polypeptide`

For each `<polypeptide id="..." source="...">` it:
- Verifies the polypeptide exists in the DB (FK safety)
- Inserts:
  - external identifiers into `polypeptide_external_identifier`
  - synonyms into `polypeptide_synonym`

---

#### Insert behavior / rerun safety
- Uses **`INSERT IGNORE`** for both tables  
  → safe to re-run without duplicating rows.

---

#### Output / reporting
The stage prints:
- pre-pass drug count
- progress bar during insertion
- periodic status logs (insert counts + speed)
- final summary table showing:
  - attempted statements
  - affected rows (`rowcount`)
  - total rows currently in DB for each Stage 8 table

It also reports:
- **missing polypeptide FK matches** (polypeptides found in XML but not present in `polypeptide` table)

</small>


In [30]:
# =============================================================================
# STAGE 8 (Jupyter single cell)
# Inserts:
#   39) polypeptide_external_identifier (FK -> polypeptide(polypeptide_id, source))
#   40) polypeptide_synonym            (FK -> polypeptide(polypeptide_id, source))
#
# Like Stage 7:
#   1) PRE-PASS: counts total <drug> in XML WITH progress bar
#   2) MAIN PASS: processes XML with progress bar based on that total
#
# Prerequisite:
#   - Stage 7 must have loaded polypeptide rows (polypeptide table)
#
# Rerun safety:
#   - Uses INSERT IGNORE into polypeptide_external_identifier and polypeptide_synonym
# =============================================================================

import time
from dataclasses import dataclass
from typing import Optional, Dict, Any, List, Tuple, Set

from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# ---------------------------
# CONFIG (EDIT THESE)
# ---------------------------
XML_PATH = r"db/drugbank.xml"   # <-- change

MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "root"
MYSQL_PASSWORD = ""
MYSQL_DB = "drugbank"

COMMIT_EVERY = 1000
SHOW_EVERY = 5000
# ---------------------------

def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
        password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=False
    )

def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def detect_namespace(xml_path: str) -> Optional[str]:
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns: Optional[str]):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem) -> Optional[str]:
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def bool_from_text(s: Optional[str]) -> Optional[bool]:
    if s is None:
        return None
    v = s.strip().lower()
    if v in ("true", "1", "yes"): return True
    if v in ("false", "0", "no"): return False
    return None

def get_primary_drugbank_id(drug_elem, nstag) -> Optional[str]:
    ids = drug_elem.findall(nstag("drugbank-id"))
    fallback = None
    primary = None
    for ide in ids:
        did = text(ide)
        if not did:
            continue
        if fallback is None:
            fallback = did
        if bool_from_text(ide.get("primary")):
            primary = did
            break
    return primary or fallback

def count_table(conn, table: str) -> int:
    cur = conn.cursor()
    cur.execute(f"SELECT COUNT(*) FROM `{table}`")
    n = int(cur.fetchone()[0])
    cur.close()
    return n

def print_table(rows: List[List[Any]], headers: List[str]) -> None:
    cols = len(headers)
    widths = [len(h) for h in headers]
    for r in rows:
        for i in range(cols):
            widths[i] = max(widths[i], len(str(r[i])))

    def sep(ch="-"):
        return "+".join(ch * (w + 2) for w in widths)

    def render(r):
        return "|".join(" " + str(r[i]).ljust(widths[i]) + " " for i in range(cols))

    print(sep("-"))
    print(render(headers))
    print(sep("="))
    for r in rows:
        print(render(r))
        print(sep("-"))

def count_total_drugs_with_progress(xml_path: str) -> int:
    total = 0
    p = tqdm(total=None, unit="drug", dynamic_ncols=True, desc="Pre-pass: counting <drug>")
    try:
        ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, elem in ctx:
            if strip_ns(elem.tag) == "drug":
                total += 1
                p.update(1)
                cleanup_lxml(elem)
    finally:
        try:
            p.close()
        except Exception:
            pass
    return total

def load_polypeptide_pk_set(conn) -> Set[Tuple[str, str]]:
    """
    Loads all existing (polypeptide_id, source) keys from DB to validate FK existence.
    This can be large; if memory becomes an issue, set this to empty and rely on FK checks.
    """
    cur = conn.cursor()
    cur.execute("SELECT polypeptide_id, source FROM polypeptide")
    s = {(row[0], row[1]) for row in cur.fetchall()}
    cur.close()
    return s

# ---------------------------
# SQL for Stage 8
# ---------------------------
SQL_POLY_EXT_ID = """
INSERT IGNORE INTO polypeptide_external_identifier (polypeptide_id, source, resource, identifier)
VALUES (%s,%s,%s,%s)
"""

SQL_POLY_SYN = """
INSERT IGNORE INTO polypeptide_synonym (polypeptide_id, source, synonym)
VALUES (%s,%s,%s)
"""

# ---------------------------
# RUN STAGE 8
# ---------------------------
print("Stage 8 pre-pass starting...")
total_drugs_xml = count_total_drugs_with_progress(XML_PATH)
print(f"Total drugs in XML: {total_drugs_xml:,}")

ns = detect_namespace(XML_PATH)
nstag = nstag_factory(ns)

conn = get_conn()
cur = conn.cursor()

# Load polypeptide keys (FK existence check)
poly_keys = load_polypeptide_pk_set(conn)
if not poly_keys:
    cur.close()
    conn.close()
    raise RuntimeError("polypeptide table is empty. Run Stage 7 first.")

attempt_ext = ins_ext = 0
attempt_syn = ins_syn = 0
processed_drugs = 0
missing_poly_fk = 0
errors = 0

t0 = time.time()
pbar = tqdm(total=total_drugs_xml, unit="drug", dynamic_ncols=True, desc="Stage 8: polypeptide children")

try:
    ctx = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
    for _, drug_elem in ctx:
        if strip_ns(drug_elem.tag) != "drug":
            continue

        try:
            # Stage 8 data lives inside targets/enzymes/carriers/transporters -> polypeptide
            groups = [
                ("targets", "target"),
                ("enzymes", "enzyme"),
                ("carriers", "carrier"),
                ("transporters", "transporter"),
            ]

            for container_tag, item_tag in groups:
                container = drug_elem.find(nstag(container_tag))
                if container is None:
                    continue

                for item in container.findall(nstag(item_tag)):
                    for poly in item.findall(nstag("polypeptide")):
                        poly_id = poly.get("id")
                        poly_source = poly.get("source")
                        if not poly_id or not poly_source:
                            continue

                        # verify polypeptide exists in DB (avoid FK error)
                        if (poly_id, poly_source) not in poly_keys:
                            missing_poly_fk += 1
                            continue

                        # external identifiers
                        ex_ids = poly.find(nstag("external-identifiers"))
                        if ex_ids is not None:
                            for exi in ex_ids.findall(nstag("external-identifier")):
                                resource = text(exi.find(nstag("resource")))
                                identifier = text(exi.find(nstag("identifier")))
                                if not resource or not identifier:
                                    continue
                                cur.execute(SQL_POLY_EXT_ID, (poly_id, poly_source, resource, identifier))
                                attempt_ext += 1
                                ins_ext += max(cur.rowcount, 0)

                        # synonyms
                        syns = poly.find(nstag("synonyms"))
                        if syns is not None:
                            for s in syns.findall(nstag("synonym")):
                                sv = text(s)
                                if not sv:
                                    continue
                                cur.execute(SQL_POLY_SYN, (poly_id, poly_source, sv))
                                attempt_syn += 1
                                ins_syn += max(cur.rowcount, 0)

            processed_drugs += 1
            pbar.update(1)

            if processed_drugs % COMMIT_EVERY == 0:
                conn.commit()

            if SHOW_EVERY and processed_drugs % SHOW_EVERY == 0:
                elapsed = time.time() - t0
                rate = processed_drugs / elapsed if elapsed > 0 else 0.0
                print(f"[Stage8] drugs={processed_drugs:,}/{total_drugs_xml:,} "
                      f"ext_ins={ins_ext:,} syn_ins={ins_syn:,} "
                      f"missing_poly_fk={missing_poly_fk:,} rate={rate:.2f}/sec")

        except Exception as e:
            errors += 1
            conn.rollback()
            print(f"[Stage8] ERROR: {e}")

        finally:
            cleanup_lxml(drug_elem)

    conn.commit()

finally:
    try:
        pbar.close()
    except Exception:
        pass
    cur.close()
    conn.close()

elapsed = time.time() - t0

# DB counts
conn2 = get_conn()
rows_ext = count_table(conn2, "polypeptide_external_identifier")
rows_syn = count_table(conn2, "polypeptide_synonym")
rows_poly = count_table(conn2, "polypeptide")
conn2.close()

print("\n=== Stage 8 Summary ===")
print_table(
    [
        ["polypeptide_external_identifier", f"{attempt_ext:,}", f"{ins_ext:,}", f"{rows_ext:,}"],
        ["polypeptide_synonym", f"{attempt_syn:,}", f"{ins_syn:,}", f"{rows_syn:,}"],
        ["polypeptide (reference)", "-", "-", f"{rows_poly:,}"],
    ],
    headers=["Table", "Attempted", "Affected(rowcount)", "Rows in DB"],
)

print("\nOther stats:")
print(f"- Drugs in XML (pre-pass):          {total_drugs_xml:,}")
print(f"- Drugs processed:                  {processed_drugs:,}")
print(f"- Missing polypeptide FK matches:   {missing_poly_fk:,}")
print(f"- Errors:                           {errors:,}")
print(f"- Elapsed:                          {elapsed:.1f}s")


Stage 8 pre-pass starting...


Pre-pass: counting <drug>: 0drug [00:00, ?drug/s]

Total drugs in XML: 73,687


Stage 8: polypeptide children:   0%|                                                       | 0/73687 [00:00<?,…

[Stage8] drugs=5,000/73,687 ext_ins=2,684 syn_ins=1,938 missing_poly_fk=0 rate=1115.16/sec
[Stage8] drugs=10,000/73,687 ext_ins=4,182 syn_ins=3,091 missing_poly_fk=0 rate=1691.20/sec
[Stage8] drugs=15,000/73,687 ext_ins=6,095 syn_ins=4,520 missing_poly_fk=0 rate=1939.74/sec
[Stage8] drugs=20,000/73,687 ext_ins=9,821 syn_ins=7,289 missing_poly_fk=0 rate=758.86/sec
[Stage8] drugs=25,000/73,687 ext_ins=12,790 syn_ins=9,958 missing_poly_fk=0 rate=349.43/sec
[Stage8] drugs=30,000/73,687 ext_ins=14,278 syn_ins=11,384 missing_poly_fk=0 rate=362.60/sec
[Stage8] drugs=35,000/73,687 ext_ins=16,099 syn_ins=12,767 missing_poly_fk=0 rate=407.32/sec
[Stage8] drugs=40,000/73,687 ext_ins=17,694 syn_ins=14,030 missing_poly_fk=0 rate=440.83/sec
[Stage8] drugs=45,000/73,687 ext_ins=18,629 syn_ins=14,793 missing_poly_fk=0 rate=481.14/sec
[Stage8] drugs=50,000/73,687 ext_ins=19,682 syn_ins=15,560 missing_poly_fk=0 rate=506.72/sec
[Stage8] drugs=55,000/73,687 ext_ins=20,651 syn_ins=16,298 missing_poly_fk=0 

<small>

### Stage 9 — Optional raw storage (`drug_raw`)

This stage stores the **entire raw `<drug>` XML record** into the database:

- **`drug_raw`** *(FK → `drug.drug_pk`)*  
  This should be the **last step per drug**, after all structured tables are loaded.

---

#### Prerequisites
Before running Stage 9, you must have:

- **Stage 1:** `drug` loaded (so `drug_pk` exists)

---

#### What the Stage 9 cell does

**1) Pre-pass: count total drugs (with progress bar)**  
- Streams the XML once to count `<drug>` elements  
- Uses that total for the main progress bar

**2) Main pass: insert/update `drug_raw`**  
- Streams the XML again and for each `<drug>`:
  - resolves `drug_pk` using `primary_drugbank_id`
  - converts the XML subtree into a **JSON-serializable dict** (tag/attrs/text/children)
  - writes it into `drug_raw.raw_json` using:
    - `INSERT ... ON DUPLICATE KEY UPDATE`  
      (safe to re-run; it updates the stored raw JSON)

---

#### MariaDB / MySQL notes
- The loader stores JSON as **TEXT/LONGTEXT** (no `CAST(... AS JSON)`).
- Recommended schema:
  - `ALTER TABLE drug_raw MODIFY raw_json LONGTEXT NOT NULL;`

If you hit `max_allowed_packet` issues (very large raw records), enable optional compression:
- set `COMPRESS_RAW = True`
- data will be stored as a string with prefix: `zlib+b64:...`

---

#### Output / reporting
The stage prints:
- pre-pass drug count
- progress bar during insertion
- periodic status logs (writes + speed)
- final summary table:
  - attempted writes
  - affected rows (`rowcount`)
  - total rows in `drug_raw`
  - runtime and error stats

</small>


In [3]:
import time, base64, zlib
from lxml import etree as ET
from tqdm.notebook import tqdm
import mysql.connector

# ---------------------------
# CONFIG (EDIT THESE)
# ---------------------------
XML_PATH = r"db/drugbank.xml"

MYSQL_HOST="localhost"
MYSQL_PORT=3306
MYSQL_USER="root"
MYSQL_PASSWORD=""
MYSQL_DB="drugbank"

COMMIT_EVERY = 200
SHOW_EVERY = 2000

# Use your new 1GB setting as a SESSION value too (important for existing sessions)
SESSION_MAX_PACKET = 1073741824  # 1GB

# Strongly recommended for Stage 9:
COMPRESS_RAW = True
# ---------------------------

SQL_DRUG_RAW_UPSERT = """
INSERT INTO drug_raw (drug_pk, raw_json)
VALUES (%s, %s)
ON DUPLICATE KEY UPDATE raw_json=VALUES(raw_json)
"""

def open_conn():
    conn = mysql.connector.connect(
        host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
        password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=False
    )
    cur = conn.cursor()
    # Ensure THIS session uses large packet size
    try:
        cur.execute(f"SET SESSION max_allowed_packet={SESSION_MAX_PACKET}")
    except Exception:
        pass
    # Print session packet once (optional)
    cur.execute("SHOW VARIABLES LIKE 'max_allowed_packet'")
    print("[Session]", cur.fetchone())
    return conn, cur

def strip_ns(tag: str) -> str:
    return tag.split("}", 1)[1] if "}" in tag else tag

def cleanup_lxml(elem):
    elem.clear()
    while elem.getprevious() is not None:
        del elem.getparent()[0]

def detect_namespace(xml_path: str):
    ctx = ET.iterparse(xml_path, events=("start",))
    _, root = next(ctx)
    if root.tag.startswith("{"):
        return root.tag.split("}")[0].strip("{")
    return None

def nstag_factory(ns):
    def nstag(local: str) -> str:
        return f"{{{ns}}}{local}" if ns else local
    return nstag

def text(elem):
    if elem is None or elem.text is None:
        return None
    t = elem.text.strip()
    return t if t else None

def pick_primary_id(drug_elem, nstag):
    """
    Robust primary id selection:
    - if @primary exists and is not false/0/no/n/empty, pick that
    - else pick first non-empty drugbank-id
    """
    ids = drug_elem.findall(nstag("drugbank-id"))
    first = None
    for ide in ids:
        did = text(ide)
        if not did:
            continue
        if first is None:
            first = did
        prim = ide.get("primary")
        if prim is not None:
            v = prim.strip().lower()
            if v not in ("false", "0", "no", "n", ""):
                return did
    return first

def load_drug_pk_map(cur):
    cur.execute("SELECT primary_drugbank_id, drug_pk FROM drug")
    return {row[0]: int(row[1]) for row in cur.fetchall()}

def count_total_drugs_with_progress(xml_path: str) -> int:
    total = 0
    p = tqdm(total=None, unit="drug", dynamic_ncols=True, desc="Pre-pass: counting <drug>")
    try:
        ctx = ET.iterparse(xml_path, events=("end",), huge_tree=True)
        for _, e in ctx:
            if strip_ns(e.tag) == "drug":
                total += 1
                p.update(1)
                cleanup_lxml(e)
    finally:
        try: p.close()
        except: pass
    return total

def encode_raw(drug_elem) -> str:
    """
    Store raw <drug> XML. Much smaller than dict->json.
    """
    raw_bytes = ET.tostring(drug_elem, encoding="utf-8", with_tail=False)
    if not COMPRESS_RAW:
        return raw_bytes.decode("utf-8", errors="replace")
    comp = zlib.compress(raw_bytes, level=6)
    return "zlib+b64:" + base64.b64encode(comp).decode("ascii")

def count_table(cur, table: str) -> int:
    cur.execute(f"SELECT COUNT(*) FROM `{table}`")
    return int(cur.fetchone()[0])

def reconnect(conn, cur):
    try:
        cur.close()
    except: 
        pass
    try:
        conn.close()
    except:
        pass
    return open_conn()

# ---------------------------
# RUN STAGE 9
# ---------------------------
print("Stage 9 pre-pass starting...")
total_drugs_xml = count_total_drugs_with_progress(XML_PATH)
print(f"Total drugs in XML: {total_drugs_xml:,}")

ns = detect_namespace(XML_PATH)
nstag = nstag_factory(ns)

conn, cur = open_conn()

drug_pk_map = load_drug_pk_map(cur)
if not drug_pk_map:
    raise RuntimeError("drug table is empty. Run Stage 1 first.")

attempted = 0
affected = 0
processed = 0
missing_drug = 0
too_large = 0
errors = 0

t0 = time.time()
pbar = tqdm(total=total_drugs_xml, unit="drug", dynamic_ncols=True, desc="Stage 9: drug_raw")

try:
    ctx = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
    for _, drug_elem in ctx:
        if strip_ns(drug_elem.tag) != "drug":
            continue

        primary_id = pick_primary_id(drug_elem, nstag)
        drug_pk = drug_pk_map.get(primary_id) if primary_id else None

        if drug_pk is None:
            missing_drug += 1
            processed += 1
            pbar.update(1)
            cleanup_lxml(drug_elem)
            continue

        try:
            payload = encode_raw(drug_elem)
            cur.execute(SQL_DRUG_RAW_UPSERT, (drug_pk, payload))
            attempted += 1
            affected += max(cur.rowcount, 0)

        except mysql.connector.Error as e:
            err = getattr(e, "errno", None)

            # packet too large
            if err == 1153:
                too_large += 1
                try:
                    conn.rollback()
                except:
                    pass
                # If not compressing, this will happen. If compressing and still happens, skip this record.
                print(f"[Stage9] SKIP drug_pk={drug_pk} (packet too large). Consider increasing max_allowed_packet further.")
            # connection lost -> reconnect and continue
            elif err in (2006, 2013, 2055):
                errors += 1
                conn, cur = reconnect(conn, cur)
                drug_pk_map = load_drug_pk_map(cur)  # reload map after reconnect
                print(f"[Stage9] Reconnected after connection drop (errno={err}). Continuing...")
            else:
                errors += 1
                try:
                    conn.rollback()
                except:
                    pass
                print(f"[Stage9] ERROR drug_pk={drug_pk} errno={err}: {e}")

        finally:
            processed += 1
            pbar.update(1)

            if processed % COMMIT_EVERY == 0:
                try:
                    conn.commit()
                except mysql.connector.Error:
                    conn, cur = reconnect(conn, cur)
                    drug_pk_map = load_drug_pk_map(cur)

            if SHOW_EVERY and processed % SHOW_EVERY == 0:
                elapsed = time.time() - t0
                rate = processed / elapsed if elapsed > 0 else 0.0
                print(f"[Stage9] drugs={processed:,}/{total_drugs_xml:,} writes={attempted:,} affected={affected:,} "
                      f"missing_drug={missing_drug:,} too_large={too_large:,} errors={errors:,} rate={rate:.2f}/sec")

            cleanup_lxml(drug_elem)

    try:
        conn.commit()
    except:
        pass

finally:
    try: pbar.close()
    except: pass
    try: cur.close()
    except: pass
    try: conn.close()
    except: pass

# Final summary
conn3, cur3 = open_conn()
rows_raw = count_table(cur3, "drug_raw")
cur3.close(); conn3.close()

elapsed = time.time() - t0
print("\n=== Stage 9 Summary ===")
print(f"Processed: {processed:,}/{total_drugs_xml:,}")
print(f"Writes attempted: {attempted:,} | affected(rowcount): {affected:,}")
print(f"Rows in drug_raw: {rows_raw:,}")
print(f"Missing drug: {missing_drug:,} | Too large skipped: {too_large:,} | Errors: {errors:,}")
print(f"Elapsed: {elapsed:.1f}s | COMPRESS_RAW={COMPRESS_RAW}")


Stage 9 pre-pass starting...


Pre-pass: counting <drug>: 0drug [00:00, ?drug/s]

KeyboardInterrupt: 

<small>

### Stage 9 — Optional raw storage (`drug_raw`)

This stage stores the **entire raw `<drug>` XML record** into the database:

- **`drug_raw`** *(FK → `drug.drug_pk`)*  
  This should be the **last step per drug**, after all structured tables are loaded.

---

#### Prerequisites
Before running Stage 9, you must have:

- **Stage 1:** `drug` loaded (so `drug_pk` exists)

---

#### What the Stage 9 cell does

**1) Pre-pass: count total drugs (with progress bar)**  
- Streams the XML once to count `<drug>` elements  
- Uses that total for the main progress bar

**2) Main pass: insert/update `drug_raw`**  
- Streams the XML again and for each `<drug>`:
  - resolves `drug_pk` using `primary_drugbank_id`
  - converts the XML subtree into a **JSON-serializable dict** (tag/attrs/text/children)
  - writes it into `drug_raw.raw_json` using:
    - `INSERT ... ON DUPLICATE KEY UPDATE`  
      (safe to re-run; it updates the stored raw JSON)

---

#### MariaDB / MySQL notes
- The loader stores JSON as **TEXT/LONGTEXT** (no `CAST(... AS JSON)`).
- Recommended schema:
  - `ALTER TABLE drug_raw MODIFY raw_json LONGTEXT NOT NULL;`

If you hit `max_allowed_packet` issues (very large raw records), enable optional compression:
- set `COMPRESS_RAW = True`
- data will be stored as a string with prefix: `zlib+b64:...`

---

# Count rows in all DrugBank tables (prints one summary table)

import mysql.connector

MYSQL_HOST="localhost"; MYSQL_PORT=3306; MYSQL_USER="root"; MYSQL_PASSWORD=""; MYSQL_DB="drugbank"

TABLES = [
    "drug","drugbank_id_map","drug_group","drug_classification","drug_classification_alternative_parent",
    "drug_classification_substituent","drug_synonym","drug_salt","product","packager","manufacturer",
    "drug_category","drug_affected_organism","drug_dosage","drug_ahfs_code","drug_pdb_entry",
    "drug_atc_code","drug_atc_level","drug_price","drug_patent","drug_food_interaction","drug_interaction",
    "drug_external_identifier","drug_external_link","drug_property","reaction","drug_salt_drugbank_id",
    "reference_item","drug_reference","interactant_reference","interactant","interactant_action",
    "interactant_polypeptide","polypeptide","polypeptide_external_identifier","polypeptide_synonym",
    "pathway","drug_pathway","pathway_drug_member","pathway_enzyme_member","reaction_element",
    "reaction_enzyme","drug_raw",
]

conn = mysql.connector.connect(host=MYSQL_HOST, port=MYSQL_PORT, user=MYSQL_USER,
                              password=MYSQL_PASSWORD, database=MYSQL_DB, autocommit=True)
cur = conn.cursor()

rows = []
for t in TABLES:
    cur.execute("SELECT COUNT(*) FROM information_schema.tables WHERE table_schema=%s AND table_name=%s", (MYSQL_DB, t))
    if int(cur.fetchone()[0]) == 0:
        rows.append((t, "N/A"))
        continue
    cur.execute(f"SELECT COUNT(*) FROM `{t}`")
    rows.append((t, f"{int(cur.fetchone()[0]):,}"))

cur.close(); conn.close()

w = max(len("Table"), *(len(t) for t,_ in rows))
print(f"{'Table'.ljust(w)}  Rows")
print("-"*(w+6))
for t,n in rows:
    print(f"{t.ljust(w)}  {n}")

#### Output / reporting
The stage prints:
- pre-pass drug count
- progress bar during insertion
- periodic status logs (writes + speed)
- final summary table:
  - attempted writes
  - affected rows (`rowcount`)
  - total rows in `drug_raw`
  - runtime and error stats

</small>


In [4]:
# Jupyter cell: Count rows in ALL tables (single summary table)
import mysql.connector

MYSQL_HOST = "localhost"
MYSQL_PORT = 3306
MYSQL_USER = "root"
MYSQL_PASSWORD = ""
MYSQL_DB = "drugbank"

def get_conn():
    return mysql.connector.connect(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        database=MYSQL_DB,
        autocommit=True,
    )

TABLES = [
    # Stage 1
    "drug",

    # Stage 2
    "drugbank_id_map",
    "drug_group",
    "drug_classification",
    "drug_classification_alternative_parent",
    "drug_classification_substituent",
    "drug_synonym",
    "drug_salt",
    "product",
    "packager",
    "manufacturer",
    "drug_category",
    "drug_affected_organism",
    "drug_dosage",
    "drug_ahfs_code",
    "drug_pdb_entry",
    "drug_atc_code",
    "drug_atc_level",
    "drug_price",
    "drug_patent",
    "drug_food_interaction",
    "drug_interaction",
    "drug_external_identifier",
    "drug_external_link",
    "drug_property",
    "reaction",

    # Stage 3
    "drug_salt_drugbank_id",

    # Stage 4
    "reference_item",
    "drug_reference",
    "interactant_reference",

    # Stage 7
    "interactant",
    "interactant_action",
    "interactant_polypeptide",

    # Stage 8
    "polypeptide",
    "polypeptide_external_identifier",
    "polypeptide_synonym",

    # Stage 5
    "pathway",
    "drug_pathway",
    "pathway_drug_member",
    "pathway_enzyme_member",

    # Stage 6
    "reaction_element",
    "reaction_enzyme",

    # Stage 9
    "drug_raw",
]

def table_exists(cur, table: str) -> bool:
    cur.execute(
        "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema=%s AND table_name=%s",
        (MYSQL_DB, table),
    )
    return int(cur.fetchone()[0]) == 1

def get_row_count(cur, table: str) -> int:
    cur.execute(f"SELECT COUNT(*) FROM `{table}`")
    return int(cur.fetchone()[0])

def print_table(rows, headers):
    cols = len(headers)
    widths = [len(h) for h in headers]
    for r in rows:
        for i in range(cols):
            widths[i] = max(widths[i], len(str(r[i])))

    def sep(ch="-"):
        return "+".join(ch * (w + 2) for w in widths)

    def render(r):
        return "|".join(" " + str(r[i]).ljust(widths[i]) + " " for i in range(cols))

    print(sep("-"))
    print(render(headers))
    print(sep("="))
    for r in rows:
        print(render(r))
        print(sep("-"))

conn = get_conn()
cur = conn.cursor()

rows = []
total_rows = 0

for t in TABLES:
    if not table_exists(cur, t):
        rows.append([t, "N/A (missing)"])
        continue
    n = get_row_count(cur, t)
    total_rows += n
    rows.append([t, f"{n:,}"])

cur.close()
conn.close()

print_table(rows, headers=["Table", "Row count"])
print(f"\nTotal rows across listed tables (sum of COUNT(*)): {total_rows:,}")


----------------------------------------+-----------
 Table                                  | Row count 
 drug                                   | 17,430    
----------------------------------------+-----------
 drugbank_id_map                        | 22,428    
----------------------------------------+-----------
 drug_group                             | 20,354    
----------------------------------------+-----------
 drug_classification                    | 11,154    
----------------------------------------+-----------
 drug_classification_alternative_parent | 0         
----------------------------------------+-----------
 drug_classification_substituent        | 0         
----------------------------------------+-----------
 drug_synonym                           | 39,977    
----------------------------------------+-----------
 drug_salt                              | 2,922     
----------------------------------------+-----------
 product                                | 462,

In [None]:
### Cross Check

In [1]:
from lxml import etree as ET

XML_PATH = r"db/drugbank.xml"

ns = None
ctx = ET.iterparse(XML_PATH, events=("start",))
_, root = next(ctx)
ns = root.tag.split("}")[0].strip("{") if root.tag.startswith("{") else None

def nstag(x): 
    return f"{{{ns}}}{x}" if ns else x

count = 0
ctx2 = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
for _, e in ctx2:
    if e.tag == nstag("ahfs-codes"):
        count += 1
        e.clear()
        break
print("Found <ahfs-codes>:", count)


Found <ahfs-codes>: 1


In [2]:
count_ap = 0
count_sub = 0
ctx2 = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
for _, e in ctx2:
    if e.tag == nstag("alternative-parents"):
        count_ap += 1
        break
    if e.tag == nstag("substituents"):
        count_sub += 1
        break
print("Found <alternative-parents>:", count_ap)
print("Found <substituents>:", count_sub)


Found <alternative-parents>: 0
Found <substituents>: 0


In [3]:
count_pw_drugs = 0
ctx2 = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
for _, e in ctx2:
    if e.tag == nstag("drugs") and e.getparent() is not None and e.getparent().tag == nstag("pathway"):
        count_pw_drugs += 1
        break
print("Found <pathway><drugs>:", count_pw_drugs)


Found <pathway><drugs>: 1


In [4]:
from lxml import etree as ET

XML_PATH = r"db/drugbank.xml"

# detect namespace
ctx = ET.iterparse(XML_PATH, events=("start",))
_, root = next(ctx)
ns = root.tag.split("}")[0].strip("{") if root.tag.startswith("{") else None
def nstag(x): return f"{{{ns}}}{x}" if ns else x

targets = {
    "ahfs-codes": False,
    "alternative-parents": False,
    "substituents": False,
    "pathway_drugs_block": False,  # <pathway><drugs>
}

ctx2 = ET.iterparse(XML_PATH, events=("end",), huge_tree=True)
for _, e in ctx2:
    tag = e.tag
    if tag == nstag("ahfs-codes"):
        targets["ahfs-codes"] = True
    if tag == nstag("alternative-parents"):
        targets["alternative-parents"] = True
    if tag == nstag("substituents"):
        targets["substituents"] = True
    if tag == nstag("drugs") and e.getparent() is not None and e.getparent().tag == nstag("pathway"):
        targets["pathway_drugs_block"] = True

    if all(targets.values()):
        break

    e.clear()
    while e.getprevious() is not None:
        del e.getparent()[0]

targets


{'ahfs-codes': True,
 'alternative-parents': False,
 'substituents': False,
 'pathway_drugs_block': True}