In [None]:
"""
YDF Log Parser - Notebook Interface
Parse timing logs from YDF synthetic-data benchmarks and convert to CSV
"""

import csv
import os
import re
from collections import defaultdict
from typing import Any, Dict, Optional

import pandas as pd

# ============================================================================
# CONFIGURATION & CONSTANTS
# ============================================================================

# Output column order for the "Exact" split type. Timing names not listed
# here are dropped from the table by the reindex in fast_parse_tree_depth.
ORDER_EXACT = [
    "Selecting Bootstrapped Samples",
    "Initialization of FindBestCondOblique",
    "SampleProjection", "ApplyProjection",
    "Bucket Allocation & Initialization=0",
    "Filling & Finalizing the Buckets", "SortFeature", "ScanSplits",
    "Post-processing after Training all Trees",
    "EvaluateProjection",
    "FillExampleBucketSet (next 3 calls)",
]

# Output column order for the "Random", "Equal Width" and "Histogram" split
# types (same filtering rule as ORDER_EXACT).
ORDER_HISTOGRAM = [
    "Selecting Bootstrapped Samples",
    "Initialization of FindBestCondOblique",
    "SampleProjection", "ApplyProjection",
    "Initializing Histogram Bins",
    "Setting Split Distributions",
    "Looping over samples",
    "Looping over splits",
    "Finding best threshold (Computing Entropies)",
    "Post-processing after Training all Trees",
]

# Label as it appears in the log -> canonical column name used in ORDER_*.
RENAMES = {
    "Post-processing after Train": "Post-processing after Training all Trees",
    "FillExampleBucketSet (calls 3 above)": "FillExampleBucketSet (next 3 calls)",
}

# Captures the number in "Training wall-time: <number>s".
TRAIN_RX = re.compile(r"Training wall-time:\s*([0-9.eE+-]+)s")
# A line containing this text starts a new tree.
BOOT_TAG = "Selecting Bootstrapped Samples"
# Prefix of a depth header; each header is counted as one node at that depth.
DEPTH_TAG = "Depth "
# Separator between a timing label and its value in seconds.
TOOK_TAG = " took:"
# Characters stripped from the left of a timing label (indent / bullet dash).
STRIP_SET = " \t-"

# ============================================================================
# PARSING FUNCTIONS
# ============================================================================

def fast_parse_tree_depth(log: str, split_type: str = "Exact") -> pd.DataFrame:
    """Parse per-tree, per-depth timing data from YDF log output.

    The log is scanned line by line:

    * a line containing ``BOOT_TAG`` starts a new tree; its time is recorded
      at depth 0 and it counts as one node at depth 0;
    * a line whose stripped text starts with ``DEPTH_TAG`` sets the current
      depth and counts one node at that depth;
    * once a tree has been seen, any other line containing ``TOOK_TAG`` is a
      timing line ``"<label> took: <seconds>[s] ..."`` attributed to the
      current tree and depth. Timing lines that appear before a tree's first
      depth header are attributed to depth 0.

    Args:
        log: Log text (ANSI escape codes should already be stripped).
        split_type: "Random", "Equal Width" or "Histogram" select
            ``ORDER_HISTOGRAM`` as the output columns; anything else selects
            ``ORDER_EXACT``.

    Returns:
        One row per (tree, depth) that has at least one timing line, with
        columns ``tree``, ``depth``, ``nodes`` followed by the selected ORDER
        columns. Each ORDER cell is the sum of that label's times in seconds
        (0.0 when absent). Labels not in the selected ORDER are dropped.

    Raises:
        ValueError: if no timing line is parsed, or a time/depth token is
            missing or not a number.
    """
    def _num(tok: str) -> float:
        # Accept both "0.0123s" and a bare "0.0123".
        tok = tok.rstrip()
        if tok.endswith('s'):
            tok = tok[:-1]
        return float(tok)

    def _took_seconds(line: str) -> float:
        # The value is the first token after TOOK_TAG; anything following it
        # on the same line is ignored.
        tokens = line.partition(TOOK_TAG)[2].split()
        if not tokens:
            raise ValueError(f"No time value after {TOOK_TAG!r} in line: {line!r}")
        return _num(tokens[0])

    rows: list[tuple[int, int, str, float]] = []
    node_counts: defaultdict[tuple[int, int], int] = defaultdict(int)

    ORDER = ORDER_HISTOGRAM if split_type in ["Random", "Equal Width", "Histogram"] else ORDER_EXACT

    cur_tree = -1
    # Depth that subsequent timing lines are attributed to. It is reset to 0
    # at every new tree, so timings logged before the tree's first depth
    # header land on the depth-0 row next to the bootstrap timing. (A None
    # depth here became a NaN group key that pivot_table silently dropped.)
    cur_depth: int = 0

    for line in log.splitlines():
        # New tree (depth 0)
        if BOOT_TAG in line:
            cur_tree += 1
            cur_depth = 0
            node_counts[(cur_tree, 0)] += 1
            # Parse like any timing line when TOOK_TAG is present; otherwise
            # fall back to the last whitespace-separated token.
            if TOOK_TAG in line:
                boot_s = _took_seconds(line)
            else:
                boot_s = _num(line.rsplit(maxsplit=1)[-1])
            rows.append((cur_tree, 0, ORDER[0], boot_s))
            continue

        # Depth header: "Depth <n> ..."
        stripped = line.lstrip()
        if stripped.startswith(DEPTH_TAG):
            cur_depth = int(stripped[len(DEPTH_TAG):].split()[0])
            node_counts[(cur_tree, cur_depth)] += 1
            continue

        # Skip lines until at least one tree seen, and all non-timing lines
        if cur_tree < 0 or TOOK_TAG not in line:
            continue

        # Timing line
        name_part = line.partition(TOOK_TAG)[0]
        time_s = _took_seconds(line)

        clean = name_part.lstrip(STRIP_SET).rstrip()
        clean = RENAMES.get(clean, clean)

        rows.append((cur_tree, cur_depth, clean, time_s))

    if not rows:
        raise ValueError("No timing lines parsed")

    df = pd.DataFrame(rows, columns=["tree", "depth", "function", "time_s"])

    # Wide table: one column per label, summed per (tree, depth); then force
    # the ORDER column set (missing -> 0.0, unknown labels dropped).
    wide = (
        df.pivot_table(index=["tree", "depth"],
                       columns="function",
                       values="time_s",
                       aggfunc="sum",
                       fill_value=0.0)
          .reindex(columns=ORDER, fill_value=0.0)
          .reset_index()
    )

    # Merge the node counts (left join: depth headers without timing lines,
    # or seen before the first tree, do not add rows).
    counts_df = pd.DataFrame(
        [(t, d, c) for (t, d), c in node_counts.items()],
        columns=["tree", "depth", "nodes"]
    )
    wide = wide.merge(counts_df, on=["tree", "depth"], how="left")
    wide["nodes"] = wide["nodes"].fillna(0).astype(int)

    cols = ["tree", "depth", "nodes"] + ORDER
    return wide[cols]

def write_csv(table: pd.DataFrame, params: Dict[str, Any], path: str):
    """Write the timing table left-aligned, with a params block to its right.

    Layout: the table's columns come first, then two empty spacer columns
    (headers ``""`` and ``"  "``), then ``Parameter`` / ``Value`` columns
    listing *params* in insertion order. Whichever side is shorter is padded
    with empty cells so both sides start on the first data row.

    Args:
        table: Timing table; rows are written in order regardless of index.
        params: Key/value pairs written to the parameter block.
        path: Destination CSV path (written with pandas' default mode "w").
    """
    p_df = pd.DataFrame(list(params.items()), columns=["Parameter", "Value"])
    n_rows = max(len(table), len(p_df))

    def _pad(frame: pd.DataFrame) -> pd.DataFrame:
        # Positional index so concat aligns row i with row i. Cast to object
        # *before* padding so integer columns keep integer values: padding an
        # int64 column with NaN would upcast it and write "0.0" instead of "0".
        return (frame.reset_index(drop=True)
                     .astype(object)
                     .reindex(range(n_rows), fill_value=""))

    gap = pd.DataFrame({"": [""] * n_rows, "  ": [""] * n_rows})

    pd.concat([_pad(table), gap, _pad(p_df)], axis=1).to_csv(
        path, index=False, quoting=csv.QUOTE_MINIMAL
    )

def extract_wall_time(log: str) -> str:
    """Return the training wall-time text from *log*, for use in a filename.

    The result is the raw text captured by ``TRAIN_RX`` (not converted to a
    number). If the log has no wall-time line, a placeholder of the form
    ``"unknown_<current unix seconds>"`` is returned instead.
    """
    found = TRAIN_RX.search(log)
    if found is None:
        import time
        return f"unknown_{int(time.time())}"
    return found.group(1)

def strip_ansi_codes(text: str) -> str:
    """Return *text* with ANSI CSI escape sequences removed.

    A sequence is ESC ``[``, then parameter bytes (0x30-0x3F), then
    intermediate bytes (0x20-0x2F), then one final byte (0x40-0x7E),
    e.g. colour codes such as ``\\x1b[31m`` and ``\\x1b[0m``.
    """
    csi_sequence = re.compile(r'\x1B\[[0-?]*[ -/]*[@-~]')
    return csi_sequence.sub('', text)

# ============================================================================
# MAIN INTERFACE
# ============================================================================

def parse_log_file(log_path: str,
                   output_dir: Optional[str] = None,
                   split_type: str = "Exact",
                   custom_params: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
    """
    Parse a YDF log file and optionally save to CSV.

    Args:
        log_path: Path to the log file. It is decoded as UTF-8; undecodable
            bytes are replaced (U+FFFD) instead of aborting the parse.
        output_dir: Directory to save CSV (if None or empty, no CSV is saved).
            The directory is created if needed; the file is named
            ``<wall_time>.csv``.
        split_type: Type of split used ("Exact", "Random", "Equal Width", "Histogram")
        custom_params: Custom parameters to include in CSV. They are applied
            after the defaults (wall_time, split_type, log_source), so equal
            keys override them.

    Returns:
        Parsed timing DataFrame

    Raises:
        OSError: if the log file cannot be opened.
        ValueError: if no timing lines can be parsed from the log.
    """
    print(f"📖 Reading log file: {log_path}")

    # Explicit encoding: the platform default varies (e.g. cp1252 on Windows)
    # and would make parsing machine-dependent or crash on non-ASCII output.
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        log_content = f.read()

    print(f"📝 Log file size: {len(log_content)} characters")

    # Strip ANSI codes
    clean_log = strip_ansi_codes(log_content)

    # Parse the log
    print(f"⚙️  Parsing with split_type: {split_type}")
    table = fast_parse_tree_depth(clean_log, split_type)

    print(f"✅ Parsed {len(table)} rows")
    print(f"   Trees: {table['tree'].nunique()}")
    print(f"   Max depth: {table['depth'].max()}")

    # Extract wall time
    wall_time = extract_wall_time(clean_log)
    print(f"⏱️  Wall time: {wall_time}s")

    # Save CSV if output directory provided
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        csv_path = os.path.join(output_dir, f"{wall_time}.csv")

        # Default parameters
        params = {
            "wall_time": wall_time,
            "split_type": split_type,
            "log_source": os.path.basename(log_path),
        }

        # Add custom parameters if provided
        if custom_params:
            params.update(custom_params)

        write_csv(table, params, csv_path)
        print(f"💾 CSV saved to: {csv_path}")

    return table

def parse_log_string(log_content: str,
                     output_path: Optional[str] = None,
                     split_type: str = "Exact",
                     custom_params: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
    """
    Parse a YDF log from string content.

    Args:
        log_content: Raw log content as string
        output_path: Full path to save CSV (if None or empty, no CSV is
            saved). May be a bare filename such as "output.csv"; a parent
            directory, if present, is created when missing.
        split_type: Type of split used ("Exact", "Random", "Equal Width", "Histogram")
        custom_params: Custom parameters to include in CSV. They are applied
            after the defaults (wall_time, split_type, log_source), so equal
            keys override them.

    Returns:
        Parsed timing DataFrame

    Raises:
        ValueError: if no timing lines can be parsed from the log.
    """
    print(f"📝 Processing log content ({len(log_content)} characters)")

    # Strip ANSI codes
    clean_log = strip_ansi_codes(log_content)

    # Parse the log
    print(f"⚙️  Parsing with split_type: {split_type}")
    table = fast_parse_tree_depth(clean_log, split_type)

    print(f"✅ Parsed {len(table)} rows")
    print(f"   Trees: {table['tree'].nunique()}")
    print(f"   Max depth: {table['depth'].max()}")

    # Extract wall time
    wall_time = extract_wall_time(clean_log)
    print(f"⏱️  Wall time: {wall_time}s")

    # Save CSV if output path provided
    if output_path:
        # For a bare filename os.path.dirname() is "", and os.makedirs("")
        # raises FileNotFoundError, so only create a directory if there is one.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Default parameters
        params = {
            "wall_time": wall_time,
            "split_type": split_type,
            "log_source": "string_input",
        }

        # Add custom parameters if provided
        if custom_params:
            params.update(custom_params)

        write_csv(table, params, output_path)
        print(f"💾 CSV saved to: {output_path}")

    return table

# ============================================================================
# EXAMPLE USAGE
# ============================================================================

if __name__ == "__main__":
    # Example 1: Parse a single log file.
    # Replace the placeholder below with the path to a real log file.
    log_file = "path/to/your/logfile.log"
    output_directory = "parsed_results"

    # Custom parameters to include in the CSV
    my_params = {
        "rows": 4096,
        "cols": 4096,
        "num_trees": 5,
        "tree_depth": -1,
        "num_threads": 1,
        "experiment_name": "my_experiment",
        "cpu_model": "Intel_i7_something"
    }

    # Parse and save, but only if the file exists: running this cell/script
    # with the placeholder path would otherwise end in a FileNotFoundError.
    if os.path.isfile(log_file):
        df = parse_log_file(log_file, output_directory, "Exact", my_params)
    else:
        print(f"⚠️  Log file not found: {log_file} (set `log_file` to a real path)")

    # Example 2: Parse from string (useful in notebooks)
    # log_string = """
    # Your log content here...
    # """
    # df = parse_log_string(log_string, "output.csv", "Exact", my_params)

    # print("Ready to parse! Uncomment the examples above or use the functions directly.")