In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Parallel inventory of CCTA DICOMs by accession (E100..., E101..., etc.), one row per SeriesInstanceUID.
- Fast pass over all files using specific_tags to group + aggregate without pixel data.
- Full header read only once per series (representative file) to export *all* header fields.
- Parallelized over accession folders.

Usage:
    python inventory_ccta_full_parallel.py \
        --root /home/eo287/mnt/s3_ccta/cta_09232025/studies \
        --out  /home/eo287/mnt/s3_ccta/summaries/ccta_series_inventory_fullheaders.csv \
        --workers 8

Env:
    python>=3.11, pydicom, pandas, tqdm
"""

import os
import sys
import json
import math
import argparse
from typing import Any, Dict, List, Tuple, Optional
from collections import defaultdict, Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from multiprocessing import cpu_count

import pandas as pd
from tqdm import tqdm

import pydicom
from pydicom.tag import BaseTag
from pydicom.datadict import keyword_for_tag, dictionary_description
from pydicom.errors import InvalidDicomError

# ----------------- Defaults / Tunables -----------------
DEFAULT_ROOT = "/home/eo287/mnt/s3_ccta/cta_09232025/studies"
DEFAULT_OUT = "/home/eo287/mnt/s3_ccta/summaries/ccta_series_inventory_fullheaders.csv"

# Flattening controls. main() overrides MAX_SEQ_DEPTH, INCLUDE_PRIVATE and
# EXCLUDE_PIXELDATA from command-line flags.
FLATTEN_SEQUENCES: bool = True  # serialize SQ elements as JSON lists of per-item dicts
MAX_SEQ_DEPTH: int = 2          # SQ elements at nesting depth >= this become item labels ("Item0", ...)
MAX_STRING_LEN: int = 2000      # serialized values longer than this are cut and marked "...<truncated>"
INCLUDE_PRIVATE: bool = True    # keep private tags in the exported header columns
EXCLUDE_PIXELDATA: bool = True  # replace a PixelData element with a placeholder string; other
                                # binary elements are NOT dropped, only length-capped as strings

# Fast-pass tags: read from every file (via specific_tags) to group files into
# series and compute per-series aggregates without parsing full headers.
FAST_TAGS = ["SeriesInstanceUID", "InstanceNumber", "ImagePositionPatient", "AcquisitionTime"]

# ----------------- Helpers -----------------
def is_dicom_path(path: str) -> bool:
    """Return True if *path* ends in ``.dcm``, compared case-insensitively.

    Only the name is inspected; the file itself is never opened.
    """
    suffix = ".dcm"
    return path[-len(suffix):].lower() == suffix

def safe_read_header_full(path: str) -> Optional[pydicom.dataset.FileDataset]:
    """Parse the complete header of *path*, stopping before pixel data.

    Meant to be called once per series on a representative file. Any read
    failure yields None instead of an exception, so the caller can fall back
    to another file of the same series.
    """
    options = dict(stop_before_pixels=True, force=True)
    try:
        dataset = pydicom.dcmread(path, **options)
    except Exception:  # InvalidDicomError is an Exception subclass, so it is covered here
        dataset = None
    return dataset

def safe_read_header_fast(path: str) -> Optional[pydicom.dataset.FileDataset]:
    """Read only the FAST_TAGS subset of *path*'s header; None on any failure.

    Cheap enough to run on every file: its result is used to group files into
    series and to collect per-series aggregates. Callers skip files for which
    this returns None.
    """
    try:
        header = pydicom.dcmread(
            path,
            force=True,
            stop_before_pixels=True,
            specific_tags=FAST_TAGS,
        )
    except Exception:  # InvalidDicomError is an Exception subclass, so it is covered here
        return None
    return header

def tag_to_key(tag: BaseTag) -> str:
    """Return a CSV column name for *tag*.

    Preference order:
      1. the DICOM dictionary keyword (e.g. ``SeriesDescription``);
      2. the dictionary description with spaces removed;
      3. ``Private_(GGGGEEEE)`` built from the tag's hex value.

    pydicom's ``dictionary_description`` raises KeyError for tags that are not
    in its dictionary (which includes private tags), so that lookup is guarded
    and such tags fall through to the hex-based name instead of raising.
    """
    keyword = keyword_for_tag(tag)  # "" when the tag has no dictionary keyword
    if keyword:
        return keyword
    try:
        description = dictionary_description(tag)
    except KeyError:
        description = ""
    if description:
        return description.replace(" ", "")
    return f"Private_({int(tag):08X})"

def _truncate(v: Any, max_len: Optional[int] = None) -> str:
    """Render *v* as a string suitable for a CSV cell, capping its length.

    Args:
        v: Value to render. ``None`` becomes ``"NA"``; anything else goes
            through ``str()`` first, so non-string values are capped too.
        max_len: Number of characters kept before the ``"...<truncated>"``
            marker is appended. Defaults to the module-level MAX_STRING_LEN.

    Returns:
        The rendered, possibly truncated, string.
    """
    if v is None:
        return "NA"
    limit = MAX_STRING_LEN if max_len is None else max_len
    text = str(v)
    if len(text) > limit:
        return text[:limit] + "...<truncated>"
    return text

def element_to_value(elem: pydicom.dataelem.DataElement, depth: int = 0) -> Any:
    """Serialize one DataElement's value into a CSV-friendly string.

    Rules, applied in order:
      * PixelData (when EXCLUDE_PIXELDATA) -> ``"<PixelData omitted>"``.
      * Sequences (VR ``SQ``) when FLATTEN_SEQUENCES:
          - depth < MAX_SEQ_DEPTH: JSON list with one ``dataset_to_dict`` result
            per item, recursing with ``depth + 1``;
          - depth >= MAX_SEQ_DEPTH: JSON list of item labels (``"Item0"``, ...).
      * Empty value (None) -> ``"NA"``.
      * Multi-valued elements -> each value stringified and joined with ``"|"``.
      * NaN / +-inf floats -> ``"NA"``.
      * Anything else -> ``str(value)``.
    String results are length-capped with ``_truncate``.

    Args:
        elem: Element to serialize.
        depth: Current sequence nesting depth (0 for top-level elements).
    """
    from collections.abc import Sequence as _AbcSequence  # local import keeps file imports unchanged

    if EXCLUDE_PIXELDATA and elem.keyword == "PixelData":
        return "<PixelData omitted>"
    val = elem.value

    # Sequence flattening
    if elem.VR == "SQ" and FLATTEN_SEQUENCES:
        if depth >= MAX_SEQ_DEPTH:
            # Too deep to expand: record only the item labels.
            try:
                return _truncate(json.dumps([f"Item{idx}" for idx, _ in enumerate(val)], ensure_ascii=False))
            except Exception:
                return f"<SQ depth>{len(val)} items"
        items: List[Dict[str, Any]] = [dataset_to_dict(it, depth=depth + 1) for it in val]
        try:
            return _truncate(json.dumps(items, ensure_ascii=False))
        except (TypeError, ValueError):  # the errors json.dumps documents for bad input
            return _truncate(str(items))

    # Empty elements: use the same "NA" marker as the rest of the CSV, not "None".
    if val is None:
        return "NA"

    # pydicom stores multi-valued elements in a MultiValue container, which is a
    # MutableSequence but not a list/tuple subclass, so test against the Sequence
    # ABC. Text and raw bytes are Sequences too but represent a single value.
    if (
        elem.VR != "SQ"
        and isinstance(val, _AbcSequence)
        and not isinstance(val, (str, bytes, bytearray, memoryview))
    ):
        return _truncate("|".join(_truncate(str(x)) for x in val))

    # Non-finite floats to "NA"
    if isinstance(val, float) and not math.isfinite(val):
        return "NA"
    return _truncate(str(val))

def dataset_to_dict(ds: pydicom.dataset.Dataset, depth: int = 0) -> Dict[str, Any]:
    """Flatten a Dataset's top-level elements into ``{column_name: value_string}``.

    Column names come from ``tag_to_key``; values from ``element_to_value``,
    which calls back into this function for sequence items (``depth`` tracks
    the nesting level). Private tags are skipped when INCLUDE_PRIVATE is False.

    Robustness:
      * Elements are visited in ascending tag order.
      * A naming or serialization failure is confined to the offending element:
        the column gets a hex-tag fallback name and/or the value
        ``"<serialization_error>"`` instead of aborting the whole dataset.
      * When two tags map to the same name, the later tag's column gets its hex
        tag appended instead of silently overwriting the earlier one.
    """
    out: Dict[str, Any] = {}
    for tag in sorted(ds.keys()):
        if not INCLUDE_PRIVATE and tag.is_private:
            continue
        try:
            key = tag_to_key(tag)
        except Exception:  # e.g. a tag the pydicom dictionary cannot name
            key = f"Tag_({int(tag):08X})"
        if key in out:
            key = f"{key}_({int(tag):08X})"
        try:
            # Element access may decode a raw value, so decoding errors are caught here too.
            out[key] = element_to_value(ds[tag], depth=depth)
        except Exception:
            out[key] = "<serialization_error>"
    return out

# ----------------- Accession processing -----------------
def _first_readable_header(paths: List[str]) -> Tuple[Any, Optional[str], int]:
    """Full-read *paths* in order and stop at the first readable header.

    Returns:
        ``(dataset, path, n_failed)``: the parsed header and the file it came
        from (both None when every read failed), plus the number of full-header
        reads that failed along the way.
    """
    n_failed = 0
    for path in paths:
        ds_full = safe_read_header_full(path)
        if ds_full is not None:
            return ds_full, path, n_failed
        n_failed += 1
    return None, None, n_failed


def process_accession(acc_root: str, acc_name: str) -> Tuple[List[Dict[str, Any]], set]:
    """
    Inventory one accession folder (runs in a worker process).

    Pass 1 reads only FAST_TAGS from every file under ``acc_root/acc_name``
    accepted by ``is_dicom_path``, groups files by SeriesInstanceUID, and
    collects per-series aggregates (instance numbers, z positions from
    ImagePositionPatient, acquisition times).
    Pass 2 reads one full header per series, from the file with the smallest
    InstanceNumber (otherwise the first file seen), falling back to the
    series' other files if that read fails, and merges it into the row.

    Args:
        acc_root: Directory that contains the accession folders.
        acc_name: Name of the accession folder to process.

    Returns:
        ``(rows, union_keys)``: one dict per series with a readable full header,
        and the union of all keys appearing in those rows.
    """
    acc_path = os.path.join(acc_root, acc_name)

    # ---- Pass 1: cheap per-file reads -> series membership + aggregates ----
    series_map: Dict[str, List[str]] = defaultdict(list)  # SeriesInstanceUID -> file paths
    inst_nums: Dict[str, List[int]] = defaultdict(list)
    z_coords: Dict[str, List[float]] = defaultdict(list)
    acq_times: Dict[str, List[str]] = defaultdict(list)
    # Representative candidate per series: (rank, path). The rank is the
    # InstanceNumber, or +inf when it is missing/unparseable so any numbered file wins.
    rep_candidate: Dict[str, Tuple[float, str]] = {}

    for dirpath, _, files in os.walk(acc_path):
        for fn in files:
            if not is_dicom_path(fn):
                continue
            fpath = os.path.join(dirpath, fn)
            ds = safe_read_header_fast(fpath)
            if ds is None:
                continue  # unreadable: cannot be attributed to any series
            siuid = getattr(ds, "SeriesInstanceUID", None)
            if not siuid:
                continue
            siuid = str(siuid)
            series_map[siuid].append(fpath)

            rank: float = math.inf
            inum = getattr(ds, "InstanceNumber", None)
            if inum is not None:
                try:
                    inum_int = int(inum)
                except (TypeError, ValueError, OverflowError):
                    pass
                else:
                    inst_nums[siuid].append(inum_int)
                    rank = inum_int
            prev = rep_candidate.get(siuid)
            if prev is None or rank < prev[0]:
                rep_candidate[siuid] = (rank, fpath)

            # pydicom returns multi-valued elements as a MultiValue container,
            # which is not a list/tuple subclass, so index it directly rather
            # than testing isinstance(list, tuple).
            ipp = getattr(ds, "ImagePositionPatient", None)
            if ipp is not None and not isinstance(ipp, (str, bytes)):
                try:
                    if len(ipp) == 3:
                        z = float(ipp[2])
                        if math.isfinite(z):
                            z_coords[siuid].append(z)
                except (TypeError, ValueError):
                    pass  # malformed position: leave it out of the z-span

            at = getattr(ds, "AcquisitionTime", None)
            if at:
                acq_times[siuid].append(str(at))

    # ---- Pass 2: one full header per series -> output rows ----
    rows: List[Dict[str, Any]] = []
    union_keys: set = set()

    for siuid, paths in series_map.items():
        preferred = rep_candidate[siuid][1] if siuid in rep_candidate else paths[0]
        ordered = [preferred] + [p for p in paths if p != preferred]
        ds_full, rep_path, n_failed = _first_readable_header(ordered)
        if ds_full is None:
            continue  # no file in this series has a readable full header

        header_dict = dataset_to_dict(ds_full, depth=0)

        nums = inst_nums.get(siuid)
        zs = z_coords.get(siuid)
        times = acq_times.get(siuid)
        row: Dict[str, Any] = {
            "AccessionFolder": acc_name,
            "SeriesInstanceUID": siuid,
            "NumImages": len(paths),
            "InstanceNumberMin": min(nums) if nums else "NA",
            "InstanceNumberMax": max(nums) if nums else "NA",
            "ZSpan": (max(zs) - min(zs)) if zs else "NA",
            "AcquisitionTimeMode": Counter(times).most_common(1)[0][0] if times else "NA",
            # Full-header reads in this series that failed before a representative was found.
            "NumHeaderReadErrors": n_failed,
            "RepresentativeFile": rep_path,
        }
        row.update(header_dict)
        rows.append(row)
        union_keys.update(row.keys())

    return rows, union_keys

# ----------------- Main -----------------
def _configure_worker(max_seq_depth: int, include_private: bool, exclude_pixeldata: bool) -> None:
    """Install the serialization settings in this process's module globals.

    main() calls this in the parent and also passes it as the
    ProcessPoolExecutor ``initializer``. That way every worker uses the
    CLI-derived values, including workers started with the "spawn" or
    "forkserver" methods, which re-import the module instead of inheriting
    the parent's modified globals.
    """
    global MAX_SEQ_DEPTH, INCLUDE_PRIVATE, EXCLUDE_PIXELDATA
    MAX_SEQ_DEPTH = max_seq_depth
    INCLUDE_PRIVATE = include_private
    EXCLUDE_PIXELDATA = exclude_pixeldata


def _series_number_sort_key(col: pd.Series) -> pd.Series:
    """Sort key for DataFrame.sort_values: order SeriesNumber numerically.

    Header values are stored as strings, so "10" would otherwise sort before "2".
    Non-numeric SeriesNumbers become NaN (sorted last); other columns are unchanged.
    """
    if col.name == "SeriesNumber":
        return pd.to_numeric(col, errors="coerce")
    return col


def main(argv: Optional[List[str]] = None) -> None:
    """Inventory every accession folder under --root and write one CSV row per series.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]`` (e.g. ``[]`` from a notebook).

    Exits with status 2 when --root is not a directory, and with status 0
    (writing nothing) when no accession folders or no series are found.
    """
    ap = argparse.ArgumentParser(description="Parallel CCTA DICOM series inventory (full headers).")
    ap.add_argument("--root", default=DEFAULT_ROOT,
                    help="Root folder containing accession folders (E100..., E101...)")
    ap.add_argument("--out",  default=DEFAULT_OUT,
                    help="Output CSV path")
    ap.add_argument("--workers", type=int, default=min(8, cpu_count()),
                    help="Number of parallel worker processes (default=min(8, cpu_count()))")
    ap.add_argument("--max-seq-depth", type=int, default=MAX_SEQ_DEPTH,
                    help="Max sequence flattening depth")
    # BooleanOptionalAction also provides --no-include-private; a store_true flag
    # with default=True could never be switched off.
    ap.add_argument("--include-private", action=argparse.BooleanOptionalAction,
                    default=INCLUDE_PRIVATE,
                    help="Include private tags")
    # NOTE: headers are read with stop_before_pixels=True, so PixelData is never
    # loaded and this flag has no visible effect on the output.
    ap.add_argument("--keep-pixeldata", action="store_true",
                    help="Include PixelData (not recommended)")

    args = ap.parse_args(argv)

    # Apply CLI settings here and (via the pool initializer) in every worker.
    settings = (args.max_seq_depth, args.include_private, not args.keep_pixeldata)
    _configure_worker(*settings)

    root = args.root
    out_csv = args.out
    workers = max(1, args.workers)  # ProcessPoolExecutor rejects max_workers <= 0

    if not os.path.isdir(root):
        print(f"ERROR: root not found: {root}", file=sys.stderr)
        sys.exit(2)

    # Accession folders: immediate subdirectories whose name starts with "E"/"e".
    accessions = sorted(
        d for d in os.listdir(root)
        if d[:1].upper() == "E" and os.path.isdir(os.path.join(root, d))
    )
    if not accessions:
        print("No accession folders (E...) found under:", root)
        sys.exit(0)

    all_rows: List[Dict[str, Any]] = []
    union_keys: set = set()

    # Parallel map over accessions
    with ProcessPoolExecutor(
        max_workers=workers,
        initializer=_configure_worker,
        initargs=settings,
    ) as ex:
        futures = {ex.submit(process_accession, root, acc): acc for acc in accessions}
        for fut in tqdm(as_completed(futures), total=len(futures), desc="Accessions (parallel)"):
            acc = futures[fut]
            try:
                rows, keys = fut.result()
            except Exception as e:
                print(f"[WARN] Accession {acc} failed: {e}", file=sys.stderr)
                continue
            if rows:
                all_rows.extend(rows)
                union_keys.update(keys)

    if not all_rows:
        print("No series discovered; nothing to write.")
        sys.exit(0)

    # Column ordering: fixed aggregate columns, then common identifiers, then
    # every other observed header key alphabetically.
    base_cols = [
        "AccessionFolder", "SeriesInstanceUID",
        "NumImages", "InstanceNumberMin", "InstanceNumberMax", "ZSpan",
        "AcquisitionTimeMode", "NumHeaderReadErrors", "RepresentativeFile",
    ]
    likely_keys = [
        "AccessionNumber", "StudyInstanceUID", "StudyID", "StudyDate", "StudyTime", "StudyDescription",
        "PatientID", "PatientSex", "PatientBirthDate",
        "SeriesNumber", "SeriesDescription", "Modality",
        "Manufacturer", "ManufacturerModelName", "InstitutionName",
    ]
    columns = list(base_cols)
    for k in likely_keys:
        if k in union_keys and k not in columns:
            columns.append(k)
    for k in sorted(union_keys):
        if k not in columns:
            columns.append(k)

    df = pd.DataFrame(all_rows, columns=columns)

    sort_keys = [c for c in ["AccessionFolder", "StudyDate", "SeriesNumber"] if c in df.columns]
    if sort_keys:
        df = df.sort_values(sort_keys, kind="stable", key=_series_number_sort_key)

    out_dir = os.path.dirname(out_csv)
    if out_dir:  # os.makedirs("") raises for a bare filename such as "inventory.csv"
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_csv, index=False)

    print(f"Wrote {len(df):,} series rows with {len(df.columns):,} columns to:\n  {out_csv}")
    with pd.option_context("display.max_columns", 20, "display.width", 220):
        print(df.head(5))

if __name__ == "__main__":
    # Inside a Jupyter kernel, sys.argv carries the kernel launcher's own
    # arguments (e.g. "--f=<kernel.json>"), which argparse rejects. Reset argv
    # only in that case so command-line flags keep working when run as a script.
    if os.path.basename(sys.argv[0]) == "ipykernel_launcher.py":
        sys.argv = ["inventory_ccta_full_parallel.py"]
    main()


usage: ipykernel_launcher.py [-h] [--root ROOT] [--out OUT]
                             [--workers WORKERS]
                             [--max-seq-depth MAX_SEQ_DEPTH]
                             [--include-private] [--keep-pixeldata]
ipykernel_launcher.py: error: unrecognized arguments: --f=/run/user/1974645248/jupyter/runtime/kernel-v3cfc045a8cfd50a2bbd6353478e0ded8e5b70f98a.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
