In [1]:
#!/usr/bin/env python3
"""
Step 0: Reproducible Environment (Colab/Jupyter adapted - top-conf/journal grade)
Generate a complete reproducible environment configuration
"""

# ===== Set environment variables directly (Colab/Jupyter env) =====
import os
import sys

os.environ["PYTHONHASHSEED"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

print("⚠️ Note: In Jupyter/Colab, PYTHONHASHSEED must be set before the kernel starts")
print("   Suggestion: After setting environment variables, restart the runtime, then run the main code\n")

# ===== Environment variables set; continue normal flow =====
import json
import hashlib
import subprocess
from pathlib import Path
from datetime import datetime, timezone
from contextlib import redirect_stdout
import io

# Check Python version
assert sys.version_info >= (3, 10), f"Require Python ≥ 3.10, current: {sys.version}"

# Create output directory
output_dir = Path("artifacts/env")
output_dir.mkdir(parents=True, exist_ok=True)

# 1. Multiple random seeds (0–9)
SEEDS = list(range(10))
print(f"Configured random seeds: {SEEDS}")

# Import and configure
import random
import numpy as np
import torch

# Initialize with the first seed
random.seed(SEEDS[0])
np.random.seed(SEEDS[0])
torch.manual_seed(SEEDS[0])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEEDS[0])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Disable TF32
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

# Enable strict deterministic algorithms (not using warn_only)
torch.use_deterministic_algorithms(True)

# Set matmul precision
if hasattr(torch, 'set_float32_matmul_precision'):
    torch.set_float32_matmul_precision("high")

# 2. Generate SEEDS.yaml (with fallback)
print("Generating SEEDS.yaml...")
seeds_config = {
    "seeds": SEEDS,
    "default_seed": SEEDS[0],
    "description": "Random seeds for python, numpy, torch, sklearn"
}
try:
    import yaml
    with open(output_dir / "SEEDS.yaml", "w") as f:
        yaml.dump(seeds_config, f, default_flow_style=False)
except ImportError:
    # Fallback if PyYAML is not installed
    yaml_content = f"""seeds: {SEEDS}
default_seed: {SEEDS[0]}
description: Random seeds for python, numpy, torch, sklearn
"""
    with open(output_dir / "SEEDS.yaml", "w") as f:
        f.write(yaml_content)

# 3. Generate requirements.txt (frozen versions)
print("Generating requirements.txt...")
result = subprocess.run(
    [sys.executable, "-m", "pip", "freeze"],
    capture_output=True, text=True
)
requirements = result.stdout
with open(output_dir / "requirements.txt", "w") as f:
    f.write(requirements)

# 4. Collect system info (for env.txt header)
import platform
system_info = []
system_info.append("="*60)
system_info.append("Environment Snapshot - System Overview")
system_info.append("="*60)
system_info.append(f"Time (UTC): {datetime.now(timezone.utc).isoformat()}")
system_info.append(f"Python: {sys.version}")
system_info.append(f"Platform: {platform.system()} {platform.release()} ({platform.machine()})")

try:
    import psutil
    system_info.append(f"CPU: {psutil.cpu_count(logical=False)} cores / {psutil.cpu_count(logical=True)} threads")
    system_info.append(f"Memory: {round(psutil.virtual_memory().total / (1024**3), 2)} GB")
except ImportError:
    pass

system_info.append(f"PyTorch: {torch.__version__}")
if torch.cuda.is_available():
    system_info.append(f"CUDA: {torch.version.cuda}")
    system_info.append(f"cuDNN: {torch.backends.cudnn.version()}")
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True, text=True
        )
        if out.returncode == 0 and out.stdout.strip():
            system_info.append(f"NVIDIA driver: {out.stdout.strip().splitlines()[0]}")
    except:
        pass

system_info.append("\nEnvironment variables:")
for key in ["PYTHONHASHSEED", "CUBLAS_WORKSPACE_CONFIG", "OMP_NUM_THREADS",
            "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "NUMEXPR_NUM_THREADS"]:
    system_info.append(f"  {key}={os.environ.get(key, 'N/A')}")

system_info.append("\n" + "="*60)
system_info.append("Installed packages list")
system_info.append("="*60 + "\n")

# 5. Generate env.txt (human-readable + system summary)
print("Generating env.txt...")
result = subprocess.run(
    [sys.executable, "-m", "pip", "list"],
    capture_output=True, text=True
)
with open(output_dir / "env.txt", "w") as f:
    f.write("\n".join(system_info))
    f.write(result.stdout)

# 6. Generate environment.yml
print("Generating environment.yml...")
env_yml = f"""name: har_lara
channels:
  - defaults
  - conda-forge
dependencies:
  - python={sys.version_info.major}.{sys.version_info.minor}
  - pip
  - pip:
"""
for line in requirements.strip().split("\n"):
    if line and not line.startswith("#"):
        env_yml += f"      - {line}\n"

with open(output_dir / "environment.yml", "w") as f:
    f.write(env_yml)

# 7. Collect complete hardware information
print("Collecting hardware information...")
hardware_info = {
    "timestamp_utc": datetime.now(timezone.utc).isoformat(),
    "python_version": sys.version,
    "python_executable": sys.executable,
    "platform": sys.platform,
    "os": platform.system(),
    "os_release": platform.release(),
    "os_version": platform.version(),
    "machine": platform.machine(),
    "processor": platform.processor(),
}

try:
    import psutil
    hardware_info["cpu_count_physical"] = psutil.cpu_count(logical=False)
    hardware_info["cpu_count_logical"] = psutil.cpu_count(logical=True)
    hardware_info["memory_total_gb"] = round(psutil.virtual_memory().total / (1024**3), 2)
except ImportError:
    pass

hardware_info["torch_version"] = torch.__version__

if torch.cuda.is_available():
    hardware_info["gpu_available"] = True
    hardware_info["gpu_count"] = torch.cuda.device_count()
    hardware_info["gpu_names"] = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
    hardware_info["cuda_version"] = torch.version.cuda
    hardware_info["cudnn_version"] = torch.backends.cudnn.version()

    gpu_details = []
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        gpu_details.append({
            "id": i,
            "name": props.name,
            "compute_capability": f"{props.major}.{props.minor}",
            "total_memory_gb": round(props.total_memory / (1024**3), 2),
            "multi_processor_count": props.multi_processor_count
        })
    hardware_info["gpu_details"] = gpu_details

    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True, text=True
        )
        if out.returncode == 0 and out.stdout.strip():
            hardware_info["nvidia_driver_version"] = out.stdout.strip().splitlines()[0]
    except:
        pass
else:
    hardware_info["gpu_available"] = False

hardware_info["deterministic_config"] = {
    "cudnn_deterministic": torch.backends.cudnn.deterministic,
    "cudnn_benchmark": torch.backends.cudnn.benchmark,
    "use_deterministic_algorithms": True,
    "warn_only": False,
    "tf32_disabled": not torch.backends.cuda.matmul.allow_tf32 if torch.cuda.is_available() else "N/A",
    "float32_matmul_precision": "high" if hasattr(torch, 'set_float32_matmul_precision') else "N/A",
    "PYTHONHASHSEED": os.environ.get("PYTHONHASHSEED"),
    "CUBLAS_WORKSPACE_CONFIG": os.environ.get("CUBLAS_WORKSPACE_CONFIG"),
    "OMP_NUM_THREADS": os.environ.get("OMP_NUM_THREADS"),
    "MKL_NUM_THREADS": os.environ.get("MKL_NUM_THREADS"),
    "OPENBLAS_NUM_THREADS": os.environ.get("OPENBLAS_NUM_THREADS"),
    "NUMEXPR_NUM_THREADS": os.environ.get("NUMEXPR_NUM_THREADS"),
}

with open(output_dir / "hardware_log.json", "w") as f:
    json.dump(hardware_info, f, indent=2)

# 8. Git commit + dirty flag
print("Collecting Git information...")
git_info = {}
try:
    git_commit = subprocess.run(
        ["git", "rev-parse", "HEAD"],
        capture_output=True, text=True, check=True
    ).stdout.strip()
    git_info["commit"] = git_commit

    git_branch = subprocess.run(
        ["git", "rev-parse", "--abbrev-ref", "HEAD"],
        capture_output=True, text=True, check=True
    ).stdout.strip()
    git_info["branch"] = git_branch

    dirty = subprocess.run(
        ["git", "status", "--porcelain"],
        capture_output=True, text=True
    ).stdout.strip()
    git_info["dirty"] = bool(dirty)
except:
    git_info["commit"] = "N/A (not a git repo)"
    git_info["dirty"] = False

with open(output_dir / "git_info.json", "w") as f:
    json.dump(git_info, f, indent=2)

# 9. PyTorch build information
print("Saving PyTorch build information...")
try:
    buf = io.StringIO()
    with redirect_stdout(buf):
        torch.__config__.show()
    (output_dir / "torch_build.txt").write_text(buf.getvalue(), encoding="utf-8")
except:
    pass

# 10. Data checksums (only original archives)
print("Generating data checksums...")
data_dir = Path("data")
if data_dir.exists():
    sha256sums = []
    archive_exts = {'.zip', '.tar', '.gz', '.tgz', '.bz2', '.xz', '.7z', '.rar'}
    for file_path in sorted(data_dir.rglob("*")):
        if file_path.is_file() and file_path.suffix.lower() in archive_exts:
            sha256 = hashlib.sha256()
            with open(file_path, "rb") as f:
                for chunk in iter(lambda: f.read(65536), b""):
                    sha256.update(chunk)
            rel_path = file_path.relative_to(data_dir)
            sha256sums.append(f"{sha256.hexdigest()}  {rel_path}")

    if sha256sums:
        with open(output_dir / "data_SHA256SUMS.txt", "w") as f:
            f.write("\n".join(sha256sums))
        print(f"  Generated checksums for {len(sha256sums)} archives")
    else:
        print("  No archives in data/ directory; skipping checksums")
else:
    print("  data/ directory does not exist; skipping checksums")

# 11. Compute environment hashes of all key files
print("Computing environment hashes...")
env_files = [
    "requirements.txt",
    "environment.yml",
    "env.txt",
    "SEEDS.yaml",
    "hardware_log.json",
    "git_info.json"
]
sha256_lines = []
for filename in env_files:
    filepath = output_dir / filename
    if filepath.exists():
        sha256 = hashlib.sha256()
        with open(filepath, "rb") as f:
            sha256.update(f.read())
        sha256_lines.append(f"{sha256.hexdigest()}  {filename}")

with open(output_dir / "ENV.SHA256", "w") as f:
    f.write("\n".join(sha256_lines))

# Output summary
print("\n" + "="*60)
print("Step 0 complete - Reproducible environment configuration (top-conf/journal grade)")
print("="*60)
print(f"Output directory: {output_dir}/")
print(f"  ✓ SEEDS.yaml (seeds: {SEEDS})")
print(f"  ✓ requirements.txt")
print(f"  ✓ env.txt (with system summary)")
print(f"  ✓ environment.yml")
print(f"  ✓ hardware_log.json")
print(f"  ✓ git_info.json (dirty={git_info.get('dirty', False)})")
print(f"  ✓ torch_build.txt")
print(f"  ✓ ENV.SHA256 (covers all key files)")
if (output_dir / "data_SHA256SUMS.txt").exists():
    print(f"  ✓ data_SHA256SUMS.txt (archives only)")

print(f"\nStrict determinism configuration:")
print(f"  - torch.use_deterministic_algorithms: True (warn_only=False)")
print(f"  - cudnn.deterministic: {torch.backends.cudnn.deterministic}")
print(f"  - cudnn.benchmark: {torch.backends.cudnn.benchmark}")
if torch.cuda.is_available():
    print(f"  - TF32 disabled: {not torch.backends.cuda.matmul.allow_tf32}")
print(f"  - Environment variables set:")
print(f"    PYTHONHASHSEED: {os.environ.get('PYTHONHASHSEED')}")
print(f"    CUBLAS_WORKSPACE_CONFIG: {os.environ.get('CUBLAS_WORKSPACE_CONFIG')}")
print(f"    Thread control: OMP/MKL/OPENBLAS/NUMEXPR=1")
print("="*60)

⚠️ Note: In Jupyter/Colab, PYTHONHASHSEED must be set before the kernel starts
   Suggestion: After setting environment variables, restart the runtime, then run the main code

Configured random seeds: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
Generating SEEDS.yaml...
Generating requirements.txt...
Generating env.txt...
Generating environment.yml...
Collecting hardware information...
Collecting Git information...
Saving PyTorch build information...
Generating data checksums...
  data/ directory does not exist; skipping checksums
Computing environment hashes...

Step 0 complete - Reproducible environment configuration (top-conf/journal grade)
Output directory: artifacts/env/
  ✓ SEEDS.yaml (seeds: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
  ✓ requirements.txt
  ✓ env.txt (with system summary)
  ✓ environment.yml
  ✓ hardware_log.json
  ✓ git_info.json (dirty=False)
  ✓ torch_build.txt
  ✓ ENV.SHA256 (covers all key files)

Strict determinism configuration:
  - torch.use_deterministic_algorithms: True (warn_o

In [2]:
#!/usr/bin/env python3
"""
Steps 1–2: Data Acquisition & Unpack Standardization (top-conf/journal grade)
Process the uploaded LARa MbientLab IMU archive
"""

import os
import hashlib
import zipfile
import shutil
import json
import re
import numpy as np
from pathlib import Path
from datetime import datetime, timezone
import pandas as pd

# ========== Helper functions ==========
def read_any_csv(path, nrows=None):
    """CSV reader with auto delimiter detection"""
    try:
        return pd.read_csv(path, nrows=nrows, sep=None, engine="python")
    except Exception:
        return pd.read_csv(path, nrows=nrows)

def infer_sampling_rate(df):
    """Infer sampling rate; auto-handle ns/μs/ms/s time units"""
    cols = [c.lower() for c in df.columns]
    time_cols = [c for c in df.columns if re.search(r"(time|timestamp|epoch)", c.lower())]
    if not time_cols:
        return None

    c = time_cols[0]
    t = pd.to_numeric(df[c], errors="coerce").dropna().to_numpy()
    if t.size < 3:
        return None

    # Infer time unit by magnitude
    max_val = np.nanmax(np.abs(t[:1000])) if t.size else 0
    if max_val >= 1e12:      # nanoseconds
        scale = 1e-9
    elif max_val >= 1e9:     # nanoseconds
        scale = 1e-9
    elif max_val >= 1e6:     # microseconds
        scale = 1e-6
    elif max_val >= 1e3:     # milliseconds
        scale = 1e-3
    else:                    # seconds
        scale = 1.0

    t_sec = t * scale
    dt = np.diff(t_sec)
    dt = dt[dt > 0]
    if dt.size == 0:
        return None

    # Use median for robustness
    return float(np.round(1.0 / np.median(dt), 3))

def infer_sensor_type(cols_lower, filename):
    """Infer sensor type"""
    if 'label' in filename.lower() or 'activity' in filename.lower():
        return "labels"

    sensors = []
    if any(("acc" in c) or ("accelerom" in c) for c in cols_lower):
        sensors.append("acc")
    if any(("gyro" in c) or re.search(r"\bgyr", c) for c in cols_lower):
        sensors.append("gyro")
    if any(("mag" in c) or ("magnetom" in c) for c in cols_lower):
        sensors.append("mag")

    return "+".join(sensors) if sensors else "unknown"

# LARa placement mapping (per official docs)
PLACEMENT_MAP = {
    "L01": "lwrist",      # Left wrist
    "L02": "rwrist",      # Right wrist
    "L03": "chest",       # Chest
    "L04": "belt",        # Belt
    "L05": "lankle",      # Left ankle
    "L06": "pocket",      # Pocket
    "L07": "lforearm",    # Left forearm
    "L08": "lupperarm",   # Left upper arm
}

# ========== Step 1: Acquire & verify ==========
print("="*60)
print("Step 1: Data acquisition & verification")
print("="*60)

# Create directory structure
raw_dir = Path("data/lara/mbientlab/raw")
raw_dir.mkdir(parents=True, exist_ok=True)

# Find uploaded zip files (prefer annotated versions)
uploaded_files = list(Path(".").glob("*annotated*MbientLab*.zip"))
if not uploaded_files:
    uploaded_files = list(Path(".").glob("*MbientLab*.zip"))
if not uploaded_files:
    uploaded_files = list(Path(".").glob("*.zip"))

if not uploaded_files:
    raise FileNotFoundError("No MbientLab data archive found; please upload a zip file first")

if len(uploaded_files) > 1:
    print(f"Warning: found multiple candidate files: {[f.name for f in uploaded_files]}")
    print(f"Using the first: {uploaded_files[0].name}")

zip_file = uploaded_files[0]
print(f"Found archive: {zip_file}")

# Move to raw data directory
target_zip = raw_dir / zip_file.name
if not target_zip.exists():
    shutil.copy2(zip_file, target_zip)
    print(f"Copied to: {target_zip}")
else:
    print(f"File already exists: {target_zip}")

# Compute SHA256 checksum
print("Computing SHA256 checksum...")
sha256_hash = hashlib.sha256()
with open(target_zip, "rb") as f:
    for chunk in iter(lambda: f.read(65536), b""):
        sha256_hash.update(chunk)

checksum = sha256_hash.hexdigest()
print(f"SHA256: {checksum}")

# Save checksum
sha256_file = raw_dir / "SHA256SUMS.txt"
with open(sha256_file, "w") as f:
    f.write(f"{checksum}  {target_zip.name}\n")
print(f"Saved checksum: {sha256_file}")

# Record provenance (traceability)
provenance = {
    "dataset": "LARa IMU-only / MbientLab",
    "origin": "manual-upload",
    "official_url": "https://sensor.informatik.uni-mannheim.de/#dataset_lara",
    "retrieved_at_utc": datetime.now(timezone.utc).isoformat(),
    "archive": target_zip.name,
    "sha256": checksum
}
(raw_dir / "PROVENANCE.json").write_text(
    json.dumps(provenance, indent=2, ensure_ascii=False),
    encoding="utf-8"
)
print(f"Recorded provenance info: {raw_dir / 'PROVENANCE.json'}")

# Set raw archive to read-only
os.chmod(target_zip, 0o444)
print(f"Set read-only permission: {target_zip}")

# ========== Step 2: Unpack & directory standardization ==========
print("\n" + "="*60)
print("Step 2: Unpack & directory standardization")
print("="*60)

# Extract to temp directory
temp_extract = raw_dir / "temp_extract"
temp_extract.mkdir(exist_ok=True)

print(f"Extracting {target_zip.name}...")
with zipfile.ZipFile(target_zip, 'r') as zip_ref:
    zip_ref.extractall(temp_extract)

# Scan extracted files and normalize
file_records = []
problems = []  # record files that failed to parse

# Recursively scan all CSV/TSV files
for file_path in temp_extract.rglob("*"):
    if not file_path.is_file():
        continue

    # Process only data files
    if file_path.suffix.lower() not in ['.csv', '.tsv', '.txt']:
        continue

    # Parse filename: LARa pattern L01_S07_R01.csv
    filename = file_path.stem

    # Extract L01/L02/L03 (placement)
    placement_match = re.search(r'L(\d+)', filename)
    placement_raw = f"L{placement_match.group(1).zfill(2)}" if placement_match else "L00"
    placement = PLACEMENT_MAP.get(placement_raw, placement_raw)

    # Extract S07 (subject)
    subject_match = re.search(r'S(\d+)', filename)
    subject_id = f"S{subject_match.group(1).zfill(2)}" if subject_match else "S00"

    # Extract R01 (session)
    session_match = re.search(r'R(\d+)', filename)
    session_id = f"R{session_match.group(1).zfill(2)}" if session_match else "R01"

    # Detect parse failures (avoid LOSO leakage)
    if subject_id == "S00" or session_id == "R01":
        if not re.search(r'R01', filename):  # exclude real R01
            problems.append(str(file_path.relative_to(temp_extract)))

    # Create standardized directory structure
    std_dir = raw_dir / subject_id / session_id / placement
    std_dir.mkdir(parents=True, exist_ok=True)

    # Standardized filename (lowercase, underscores)
    std_filename = file_path.name.lower().replace(' ', '_').replace('-', '_')
    std_path = std_dir / std_filename

    # Copy to standardized location
    if not std_path.exists():
        shutil.copy2(file_path, std_path)

    # Get file info
    file_size = file_path.stat().st_size
    num_rows = 0
    sampling_rate = None
    duration = None
    sensor_type = "unknown"

    try:
        # Read sample
        df_sample = read_any_csv(file_path, nrows=2000)
        columns_lower = [c.lower() for c in df_sample.columns]

        # Infer sensor type
        sensor_type = infer_sensor_type(columns_lower, filename)

        # Infer sampling rate (skip for labels)
        if sensor_type != "labels":
            sampling_rate = infer_sampling_rate(df_sample)

        # Count total rows (streaming to avoid loading big files)
        with open(file_path, "rb") as fh:
            num_rows = sum(1 for _ in fh) - 1  # minus header

        # Compute duration
        if sampling_rate and num_rows > 0:
            duration = round(num_rows / sampling_rate, 2)

    except Exception:
        pass  # silently skip files that cannot be parsed

    # Record file info
    file_records.append({
        "subject_id": subject_id,
        "session_id": session_id,
        "placement": placement,
        "placement_raw": placement_raw,
        "sensor_type": sensor_type,
        "original_path": str(file_path.relative_to(temp_extract)),
        "standardized_path": str(std_path.relative_to(raw_dir)),
        "filename": std_filename,
        "file_size_bytes": file_size,
        "num_rows": num_rows,
        "sampling_rate_hz": sampling_rate,
        "duration_sec": duration,
    })

print(f"Processed {len(file_records)} files")

# Check parse failures
if problems:
    problems_file = raw_dir / "PROBLEMS.log"
    problems_file.write_text(
        "The following files could not parse subject/session (would break LOSO):\n" +
        "\n".join(problems) + "\n",
        encoding="utf-8"
    )
    raise RuntimeError(
        f"Found {len(problems)} files with unparsed subject/session; "
        f"please check {problems_file} and fix"
    )

# Remove temp extraction directory
shutil.rmtree(temp_extract)
print("Removed temporary files")

# Generate file_index (Parquet preferred; fallback to CSV)
if file_records:
    file_index = pd.DataFrame(file_records)

    # Sort
    file_index = file_index.sort_values(
        ['subject_id', 'session_id', 'placement', 'sensor_type']
    )

    # Save index
    index_file = raw_dir / "file_index.parquet"
    try:
        file_index.to_parquet(index_file, index=False)
        saved_index = index_file
        print(f"\nGenerated file index: {saved_index}")
    except Exception as e:
        print(f"Warning: Parquet write failed ({e}); falling back to CSV")
        index_file_csv = raw_dir / "file_index.csv"
        file_index.to_csv(index_file_csv, index=False)
        saved_index = index_file_csv
        print(f"Generated file index: {saved_index}")

    # Show dataset statistics
    print("\nDataset statistics:")
    print(f"  Number of subjects: {file_index['subject_id'].nunique()}")
    print(f"  Number of sessions: {file_index.groupby('subject_id')['session_id'].nunique().sum()}")
    print(f"  Placements: {sorted(file_index['placement'].unique().tolist())}")
    print(f"  Sensor types: {sorted(file_index['sensor_type'].unique().tolist())}")
    print(f"  Total files: {len(file_index)}")

    # Sampling rate stats
    sensor_files = file_index[file_index['sensor_type'] != 'labels']
    if not sensor_files.empty:
        rates = sensor_files['sampling_rate_hz'].dropna()
        if not rates.empty:
            print(f"  Sampling rate range: {rates.min():.1f} - {rates.max():.1f} Hz")
            print(f"  Median sampling rate: {rates.median():.1f} Hz")

    # Preview first records
    print("\nFile index preview:")
    print(file_index.head(10).to_string())
else:
    print("Warning: No data files found")

print("\n" + "="*60)
print("Steps 1–2 complete (top-conf/journal grade)")
print("="*60)
print(f"Raw data: {raw_dir}/")
print(f"Checksum: {sha256_file}")
print(f"Provenance record: {raw_dir / 'PROVENANCE.json'}")
print(f"File index: {saved_index}")
print("="*60)

Step 1: Data acquisition & verification
Found archive: IMU data (annotated) _ MbientLab.zip
Copied to: data/lara/mbientlab/raw/IMU data (annotated) _ MbientLab.zip
Computing SHA256 checksum...
SHA256: 70968b6b8874375e96671af67e31c27ccb63793f31191f86e732d40f24ac3106
Saved checksum: data/lara/mbientlab/raw/SHA256SUMS.txt
Recorded provenance info: data/lara/mbientlab/raw/PROVENANCE.json
Set read-only permission: data/lara/mbientlab/raw/IMU data (annotated) _ MbientLab.zip

Step 2: Unpack & directory standardization
Extracting IMU data (annotated) _ MbientLab.zip...
Processed 386 files
Removed temporary files

Generated file index: data/lara/mbientlab/raw/file_index.parquet

Dataset statistics:
  Number of subjects: 8
  Number of sessions: 193
  Placements: ['chest', 'lwrist', 'rwrist']
  Sensor types: ['acc+gyro', 'labels']
  Total files: 386

File index preview:
   subject_id session_id placement placement_raw sensor_type                                                original_path      

In [3]:
#!/usr/bin/env python3
"""
Step 3: Metadata & Quality Audit (top-conf/journal grade - final)
Parse subjects, activity set, sampling rate, placement, session time; empty-window cleanup
"""

import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime, timezone
import json
import re

# ========== Config ==========
MISSING_THRESHOLD = 0.05      # Missing-rate threshold 5%
GAP_THRESHOLD = 2.0           # Single-gap absolute threshold (seconds)
GAP_RATIO_THRESHOLD = 0.05    # Gap ratio threshold 5%

print("="*60)
print("Step 3: Metadata & Quality Audit")
print("="*60)

# Load file index
raw_dir = Path("data/lara/mbientlab/raw")
index_file = raw_dir / "file_index.parquet"
if not index_file.exists():
    index_file = raw_dir / "file_index.csv"

print(f"Loading file index: {index_file}")
file_index = pd.read_parquet(index_file) if index_file.suffix == '.parquet' else pd.read_csv(index_file)

# Initialize variables (avoid undefined in edge cases)
subject_agg = pd.DataFrame()
meta_subjects_file = None
meta_sessions_file = None
keep_sessions_file = None

# ========== Helper functions ==========
def pick_scale(med_raw, sr_hint=None):
    """Smartly pick time unit (s/ms/μs/ns → seconds)"""
    cands = [1.0, 1e-3, 1e-6, 1e-9]

    if sr_hint and sr_hint > 0:
        target_dt = 1.0 / sr_hint
        return min(cands, key=lambda s: abs(med_raw * s - target_dt))

    # Without hint: prefer median interval mapping into 5-400 Hz, bias toward ~50 Hz
    best, err = 1.0, float("inf")
    for s in cands:
        dt = med_raw * s
        if dt <= 0:
            continue
        sr = 1.0 / dt
        score = 0 if 5 <= sr <= 400 else abs(sr - 50) * 10
        if score < err:
            best, err = s, score
    return best

def extract_time_range_and_gaps(file_path, sampling_rate_hint=None, head_rows=20000, chunksize=200000):
    """Read time column in chunks; extract range and gaps (incl. inter-chunk gaps, memory-friendly)"""
    try:
        # Infer time column & unit from a small sample
        df_head = pd.read_csv(file_path, nrows=head_rows, sep=None, engine="python")
        time_cols = [c for c in df_head.columns if re.search(r"(time|timestamp|epoch|ts)", c, re.I)]
        if not time_cols:
            return None, None, 0.0, 0.0, 0.0

        c = time_cols[0]
        s = pd.to_numeric(df_head[c], errors="coerce").dropna().to_numpy()

        # Numeric timestamp branch
        if s.size >= 3:
            diffs = np.diff(s)
            diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
            if diffs.size > 0:
                med = float(np.median(diffs))
                scale = pick_scale(med, sampling_rate_hint)
                expected = (1.0 / sampling_rate_hint) if (sampling_rate_hint and sampling_rate_hint > 0) else (med * scale)

                # OR logic: two independent thresholds
                rel_threshold = 10.0 * expected  # Relative threshold: 10× expected interval
                abs_threshold = GAP_THRESHOLD    # Absolute threshold: 2 s

                first = None
                last = None
                prev = None
                gap_sec = 0.0
                max_gap = 0.0

                for chunk in pd.read_csv(file_path, usecols=[c], sep=None, engine="python", chunksize=chunksize):
                    v = pd.to_numeric(chunk[c], errors="coerce").dropna().to_numpy()
                    if v.size == 0:
                        continue

                    if first is None:
                        first = v[0]

                    # Inter-chunk gaps (fix: use max as baseline)
                    if prev is not None:
                        delta = (v[0] - prev) * scale
                        cond_rel = delta > rel_threshold
                        cond_abs = delta > abs_threshold

                        if cond_rel or cond_abs:
                            # If both trigger, use max (more lenient); if only one, use that one
                            if cond_rel and cond_abs:
                                base = max(rel_threshold, abs_threshold)
                            elif cond_rel:
                                base = rel_threshold
                            else:
                                base = abs_threshold

                            gap_this = delta - base
                            gap_sec += gap_this
                            max_gap = max(max_gap, gap_this)

                    # Intra-chunk gaps (fix: shape + baseline)
                    d = np.diff(v) * scale
                    mask_rel = d > rel_threshold
                    mask_abs = d > abs_threshold
                    mask = mask_rel | mask_abs

                    if mask.any():
                        # Vectorized: choose the threshold triggered by each gap (use max if both)
                        both_triggered = mask_rel & mask_abs
                        thr_used = np.where(
                            both_triggered,
                            max(rel_threshold, abs_threshold),
                            np.where(mask_rel, rel_threshold, abs_threshold)
                        )
                        gaps = d[mask] - thr_used[mask]  # Fix: also index thr_used
                        gap_sec += float(gaps.sum())
                        max_gap = max(max_gap, float(gaps.max()))

                    prev = v[-1]
                    last = v[-1]

                if first is not None and last is not None:
                    start_sec = float(first * scale)
                    end_sec = float(last * scale)
                    total = end_sec - start_sec
                    ratio = float(gap_sec / total) if total > 0 else 0.0
                    return start_sec, end_sec, float(round(gap_sec, 2)), float(round(ratio, 4)), float(round(max_gap, 2))

        # Fallback branch: datetime strings
        t_head = pd.to_datetime(df_head[c], utc=True, errors="coerce").dropna()
        if t_head.size >= 3:
            med = float(t_head.diff().dt.total_seconds().dropna().median())
            if med > 0:
                expected = (1.0 / sampling_rate_hint) if (sampling_rate_hint and sampling_rate_hint > 0) else med

                # OR logic
                rel_threshold = 10.0 * expected
                abs_threshold = GAP_THRESHOLD

                first = None
                last = None
                prev = None
                gap_sec = 0.0
                max_gap = 0.0

                for chunk in pd.read_csv(file_path, usecols=[c], sep=None, engine="python", chunksize=chunksize):
                    tt = pd.to_datetime(chunk[c], utc=True, errors="coerce").dropna()
                    if tt.empty:
                        continue

                    if first is None:
                        first = tt.iloc[0]

                    # Inter-chunk gaps (fix: use max as baseline)
                    if prev is not None:
                        delta = (tt.iloc[0] - prev).total_seconds()
                        cond_rel = delta > rel_threshold
                        cond_abs = delta > abs_threshold

                        if cond_rel or cond_abs:
                            if cond_rel and cond_abs:
                                base = max(rel_threshold, abs_threshold)
                            elif cond_rel:
                                base = rel_threshold
                            else:
                                base = abs_threshold

                            gap_this = delta - base
                            gap_sec += gap_this
                            max_gap = max(max_gap, gap_this)

                    # Intra-chunk gaps (fix: shape + baseline)
                    d = tt.diff().dt.total_seconds().dropna()
                    mask_rel = d > rel_threshold
                    mask_abs = d > abs_threshold
                    mask = mask_rel | mask_abs

                    if not mask.empty and mask.any():
                        both_triggered = mask_rel & mask_abs
                        thr_used = np.where(
                            both_triggered,
                            max(rel_threshold, abs_threshold),
                            np.where(mask_rel, rel_threshold, abs_threshold)
                        )
                        gaps = d[mask].values - thr_used[mask]
                        gap_sec += float(gaps.sum())
                        max_gap = max(max_gap, float(gaps.max()))

                    prev = tt.iloc[-1]
                    last = tt.iloc[-1]

                if first is not None and last is not None:
                    total = (last - first).total_seconds()
                    ratio = float(gap_sec / total) if total > 0 else 0.0
                    return first.timestamp(), last.timestamp(), float(round(gap_sec, 2)), float(round(ratio, 4)), float(round(max_gap, 2))

        return None, None, 0.0, 0.0, 0.0

    except Exception:
        return None, None, 0.0, 0.0, 0.0

def safe_float(x, default=0.0):
    """Safely cast to float, handling NaN/Inf"""
    try:
        if x is None or (isinstance(x, float) and (np.isnan(x) or np.isinf(x))):
            return default
        return float(x)
    except:
        return default

# ========== 1. Parse sensor data metadata ==========
print("\n" + "="*60)
print("1. Parse sensor data metadata")
print("="*60)

# Determine label files directly from filenames (more reliable)
label_files = file_index[
    file_index['filename'].str.contains('label', case=False, na=False)
].copy()
sensor_files = file_index[
    ~file_index['filename'].str.contains('label', case=False, na=False)
].copy()

print(f"Sensor files: {len(sensor_files)}")
print(f"Label files: {len(label_files)}")

# Extract time ranges for sensor files (receive 5 return values)
print("Extracting time spans and gap statistics (chunked)...")
time_records = []
for idx, row in sensor_files.iterrows():
    file_path = raw_dir / row['standardized_path']
    start, end, gap_sec, gap_ratio, max_gap = extract_time_range_and_gaps(
        file_path,
        row['sampling_rate_hz']
    )
    time_records.append({
        'subject_id': row['subject_id'],
        'session_id': row['session_id'],
        'placement': row['placement'],
        'start_time': start,
        'end_time': end,
        'gap_seconds': gap_sec,
        'gap_ratio': gap_ratio,
        'max_gap_seconds': max_gap,
    })

df_time_ranges = pd.DataFrame(time_records)

# Aggregate time ranges by session (includes max_gap)
session_time_agg = df_time_ranges.groupby(['subject_id', 'session_id']).agg({
    'start_time': 'min',
    'end_time': 'max',
    'gap_seconds': 'sum',
    'max_gap_seconds': 'max',
}).reset_index()

session_time_agg['session_duration_sec'] = (
    session_time_agg['end_time'] - session_time_agg['start_time']
)
session_time_agg['gap_ratio'] = (
    session_time_agg['gap_seconds'] / session_time_agg['session_duration_sec']
).fillna(0.0).infer_objects(copy=False)

session_time_agg.rename(columns={
    'start_time': 'session_start_time',
    'end_time': 'session_end_time'
}, inplace=True)

# Add ISO8601 (human-readable) times
def to_iso(x):
    try:
        if pd.notna(x):
            return datetime.fromtimestamp(float(x), tz=timezone.utc).isoformat()
    except:
        pass
    return None

session_time_agg['session_start_utc'] = session_time_agg['session_start_time'].apply(to_iso)
session_time_agg['session_end_utc'] = session_time_agg['session_end_time'].apply(to_iso)

print(f"Extracted time spans for {len(session_time_agg)} sessions")

# ========== 2. Parse labels & activity statistics ==========
print("\n" + "="*60)
print("2. Parse labels & activity statistics")
print("="*60)

activity_stats = []
session_records = []

for idx, label_row in label_files.iterrows():
    label_path = raw_dir / label_row['standardized_path']

    if not label_path.exists():
        continue

    try:
        # Read label file
        df_label = pd.read_csv(label_path, sep=None, engine='python')

        # Find label column (LARa dataset uses 'Class')
        if 'Class' in df_label.columns:
            label_col = 'Class'
        elif 'class' in df_label.columns:
            label_col = 'class'
        else:
            label_cols = [c for c in df_label.columns if 'label' in c.lower() or 'activity' in c.lower()]
            if not label_cols:
                print(f"  No label column ({df_label.columns.tolist()}): {label_path.name}")
                continue
            label_col = label_cols[0]

        # Count activity distribution
        activity_counts = df_label[label_col].value_counts()
        total_samples = len(df_label)

        # Check missing
        missing_count = df_label[label_col].isna().sum()
        missing_rate = missing_count / total_samples if total_samples > 0 else 0

        # Record session info
        session_info = {
            'subject_id': label_row['subject_id'],
            'session_id': label_row['session_id'],
            'placement': label_row['placement'],
            'total_samples': total_samples,
            'missing_samples': missing_count,
            'missing_rate': round(missing_rate, 4),
            'num_activities': len(activity_counts),
        }

        # Add per-activity stats
        for activity, count in activity_counts.items():
            activity_stats.append({
                'subject_id': label_row['subject_id'],
                'session_id': label_row['session_id'],
                'placement': label_row['placement'],
                'activity': str(activity),
                'count': int(count),
                'percentage': round(count / total_samples * 100, 2)
            })

        session_records.append(session_info)

    except Exception as e:
        print(f"  Warning: failed to parse {label_path.name}: {e}")
        continue

print(f"Parsed {len(session_records)} sessions")

# ========== 2.1 Orphan session check ==========
print("\nChecking orphan sessions...")
sess_from_sensors = set(zip(sensor_files['subject_id'], sensor_files['session_id']))
sess_from_labels = set(zip(label_files['subject_id'], label_files['session_id']))
orphans = sess_from_sensors - sess_from_labels

if orphans:
    orphan_file = raw_dir / "QA_ISSUES.log"
    with open(orphan_file, "a", encoding="utf-8") as f:
        f.write("\nSessions with sensors but no labels (orphan sessions):\n")
        for s, r in sorted(orphans):
            f.write(f"  {s}-{r}\n")
    print(f"⚠️  Found {len(orphans)} orphan sessions; logged to QA_ISSUES.log")

# ========== 3. Merge session metadata ==========
print("\n" + "="*60)
print("3. Merge session metadata")
print("="*60)

df_sessions = pd.DataFrame(session_records)
df_activities = pd.DataFrame(activity_stats)

# Merge time info
if not df_sessions.empty and not session_time_agg.empty:
    df_sessions = df_sessions.merge(
        session_time_agg,
        on=['subject_id', 'session_id'],
        how='left'
    )
    print(f"Merged time span info")

# ========== 4. Data quality checks & empty-window cleanup ==========
print("\n" + "="*60)
print("4. Data quality checks & empty-window cleanup")
print("="*60)

if not df_sessions.empty:
    # Generate keep flag
    df_sessions['keep'] = True
    df_sessions['reject_reason'] = ''

    # Check missing-rate exceeds threshold
    high_missing_mask = df_sessions['missing_rate'] > MISSING_THRESHOLD
    if high_missing_mask.any():
        df_sessions.loc[high_missing_mask, 'keep'] = False
        df_sessions.loc[high_missing_mask, 'reject_reason'] = 'high_missing_rate'
        print(f"⚠️  {high_missing_mask.sum()} sessions marked not kept due to high missing rate")

    # Check time-gap ratio exceeds threshold
    if 'gap_ratio' in df_sessions.columns:
        high_gap_mask = df_sessions['gap_ratio'] > GAP_RATIO_THRESHOLD
        if high_gap_mask.any():
            # Append reason if already rejected; otherwise mark alone
            for idx in df_sessions[high_gap_mask].index:
                if df_sessions.loc[idx, 'keep']:
                    df_sessions.loc[idx, 'keep'] = False
                    df_sessions.loc[idx, 'reject_reason'] = 'high_gap_ratio'
                else:
                    df_sessions.loc[idx, 'reject_reason'] += '+high_gap_ratio'
            print(f"⚠️  {high_gap_mask.sum()} sessions marked not kept due to high gap ratio")

    # Summary
    keep_count = df_sessions['keep'].sum()
    reject_count = (~df_sessions['keep']).sum()
    print(f"✓ QC result: keep {keep_count} sessions, reject {reject_count} sessions")

    # Save keep list
    keep_sessions_file = raw_dir / "qa_keep_sessions.csv"
    df_sessions[['subject_id', 'session_id', 'placement', 'keep', 'reject_reason',
                 'missing_rate', 'gap_ratio']].to_csv(keep_sessions_file, index=False)
    print(f"✓ Saved: {keep_sessions_file}")

    # Log rejection details
    if reject_count > 0:
        rejected = df_sessions[~df_sessions['keep']]
        qa_issues = raw_dir / "QA_ISSUES.log"
        with open(qa_issues, "a") as f:
            f.write(f"\nSessions rejected by QC (total {reject_count}):\n\n")
            f.write(rejected[['subject_id', 'session_id', 'placement', 'reject_reason',
                             'missing_rate', 'gap_ratio']].to_string(index=False))
        print(f"  Details logged to: {qa_issues}")

# ========== 4.1 Generate file-level empty-window list ==========
print("\nGenerating file-level empty-window list...")
if not df_time_ranges.empty:
    empty_segments = df_time_ranges[
        df_time_ranges['gap_ratio'].notna() &
        (df_time_ranges['gap_ratio'] > GAP_RATIO_THRESHOLD)
    ].copy()

    if not empty_segments.empty:
        empty_todo_file = raw_dir / "EMPTY_SEGMENTS_TODO.csv"
        empty_segments[['subject_id', 'session_id', 'placement',
                       'gap_seconds', 'gap_ratio', 'max_gap_seconds']].to_csv(empty_todo_file, index=False)
        print(f"⚠️  Generated empty-segment list: {empty_todo_file} ({len(empty_segments)} files)")

# ========== 5. Generate subject-level metadata ==========
print("\n" + "="*60)
print("5. Generate subject-level metadata")
print("="*60)

if not df_sessions.empty:
    # Only count kept sessions
    df_keep = df_sessions[df_sessions['keep']]

    if not df_keep.empty:
        # Aggregate by subject
        subject_agg = df_keep.groupby('subject_id').agg({
            'session_id': 'nunique',
            'total_samples': 'sum',
            'missing_samples': 'sum',
            'session_duration_sec': 'sum',
            'num_activities': 'sum',
        }).reset_index()

        subject_agg.columns = ['subject_id', 'num_sessions', 'total_samples',
                               'total_missing', 'total_duration_sec', 'total_activities']

        # Compute overall missing rate
        subject_agg['overall_missing_rate'] = (
            subject_agg['total_missing'] / subject_agg['total_samples']
        ).round(4)

        # Add placement coverage
        placement_coverage = df_keep.groupby('subject_id')['placement'].apply(
            lambda x: ','.join(sorted(set(x)))
        ).reset_index()
        placement_coverage.columns = ['subject_id', 'placements']

        subject_agg = subject_agg.merge(placement_coverage, on='subject_id')

        # Save subject metadata
        meta_subjects_file = raw_dir / "meta_subjects.csv"
        subject_agg.to_csv(meta_subjects_file, index=False)
        print(f"✓ Saved: {meta_subjects_file}")
        print(f"  Number of subjects: {len(subject_agg)}")

# ========== 6. Generate session-level metadata ==========
print("\n" + "="*60)
print("6. Generate session-level metadata")
print("="*60)

if not df_sessions.empty:
    # Add activity list
    if not df_activities.empty:
        activity_list = df_activities.groupby(['subject_id', 'session_id'])['activity'].apply(
            lambda x: ','.join(sorted(set(x)))
        ).reset_index()
        activity_list.columns = ['subject_id', 'session_id', 'activities']

        df_sessions_full = df_sessions.merge(
            activity_list,
            on=['subject_id', 'session_id'],
            how='left'
        )
    else:
        df_sessions_full = df_sessions

    # Save session metadata
    meta_sessions_file = raw_dir / "meta_sessions.csv"
    df_sessions_full.to_csv(meta_sessions_file, index=False)
    print(f"✓ Saved: {meta_sessions_file}")
    print(f"  Number of sessions: {len(df_sessions_full)}")

# ========== 7. Generate quality audit report ==========
print("\n" + "="*60)
print("7. Generate quality audit report")
print("="*60)

qa_report = []
qa_report.append("="*70)
qa_report.append("LARa MbientLab IMU Dataset - Quality Audit Report")
qa_report.append("="*70)
qa_report.append(f"Generated at: {datetime.now(timezone.utc).isoformat()}")
qa_report.append(f"Data path: {raw_dir}")
qa_report.append("")

# Overall stats
qa_report.append("[1. Dataset overview]")
qa_report.append("-"*70)
if not subject_agg.empty:
    total_hours = safe_float(subject_agg['total_duration_sec'].sum() / 3600)
    qa_report.append(f"Number of subjects: {len(subject_agg)}")
    qa_report.append(f"Total sessions: {subject_agg['num_sessions'].sum()}")
    qa_report.append(f"Total duration: {total_hours:.2f} hours")
    qa_report.append(f"Total samples: {subject_agg['total_samples'].sum():,}")
qa_report.append("")

# Sampling rate stats
qa_report.append("[2. Sampling rate statistics]")
qa_report.append("-"*70)
if not sensor_files.empty:
    rates = sensor_files['sampling_rate_hz'].dropna()
    if not rates.empty:
        qa_report.append(f"Sampling rate range: {rates.min():.2f} - {rates.max():.2f} Hz")
        qa_report.append(f"Median sampling rate: {rates.median():.2f} Hz")
        qa_report.append(f"Mode sampling rate: {rates.mode().values[0]:.2f} Hz")
qa_report.append("")

# Placement coverage
qa_report.append("[3. Sensor placement coverage]")
qa_report.append("-"*70)
if not df_sessions.empty:
    df_keep = df_sessions[df_sessions['keep']]
    if not df_keep.empty:
        placement_dist = df_keep['placement'].value_counts()
        for placement, count in placement_dist.items():
            percentage = count / len(df_keep) * 100
            qa_report.append(f"  {placement:15s}: {count:3d} sessions ({percentage:5.1f}%)")
qa_report.append("")

# Activity distribution
qa_report.append("[4. Activity distribution]")
qa_report.append("-"*70)
if not df_activities.empty:
    activity_total = df_activities.groupby('activity').agg({
        'count': 'sum',
    }).sort_values('count', ascending=False)

    total_count = activity_total['count'].sum()
    qa_report.append(f"Number of activity classes: {len(activity_total)}")
    qa_report.append(f"Total samples: {total_count:,}")
    qa_report.append("")
    qa_report.append("Per-activity share:")
    for activity, row in activity_total.iterrows():
        percentage = row['count'] / total_count * 100
        qa_report.append(f"  {str(activity):30s}: {row['count']:8,} ({percentage:5.2f}%)")
qa_report.append("")

# Data quality (incl. max_gap stats)
qa_report.append("[5. Data quality assessment]")
qa_report.append("-"*70)
if not df_sessions.empty:
    qa_report.append(f"Missing-rate threshold: {MISSING_THRESHOLD*100}%")
    qa_report.append(f"Gap absolute threshold: {GAP_THRESHOLD} s")
    qa_report.append(f"Gap relative threshold: 10× expected interval")
    qa_report.append(f"Gap ratio threshold: {GAP_RATIO_THRESHOLD*100}%")

    avg_miss = safe_float(df_sessions['missing_rate'].mean())
    max_miss = safe_float(df_sessions['missing_rate'].max())
    med_miss = safe_float(df_sessions['missing_rate'].median())

    qa_report.append(f"Overall average missing rate: {avg_miss*100:.2f}%")
    qa_report.append(f"Max missing rate: {max_miss*100:.2f}%")
    qa_report.append(f"Median missing rate: {med_miss*100:.2f}%")

    if 'gap_ratio' in df_sessions.columns:
        avg_gap = safe_float(df_sessions['gap_ratio'].mean())
        max_gap_ratio = safe_float(df_sessions['gap_ratio'].max())
        qa_report.append(f"Average gap ratio: {avg_gap*100:.2f}%")
        qa_report.append(f"Max gap ratio: {max_gap_ratio*100:.2f}%")

    if 'max_gap_seconds' in df_sessions.columns:
        max_single_gap = safe_float(df_sessions['max_gap_seconds'].max())
        qa_report.append(f"Max single gap: {max_single_gap:.2f} s")

    keep_count = df_sessions['keep'].sum()
    total_count = len(df_sessions)
    pass_rate = keep_count / total_count * 100 if total_count > 0 else 0
    qa_report.append(f"")
    qa_report.append(f"Sessions passing QC: {keep_count}/{total_count} ({pass_rate:.1f}%)")

if (raw_dir / "EMPTY_SEGMENTS_TODO.csv").exists():
    qa_report.append("")
    qa_report.append("[Note] Empty/abnormal segments found; see: EMPTY_SEGMENTS_TODO.csv (exclude during later sliding-window segmentation)")

qa_report.append("")

# Per-subject details
qa_report.append("[6. Subject-level details]")
qa_report.append("-"*70)
if not subject_agg.empty:
    for _, subj in subject_agg.iterrows():
        qa_report.append(f"Subject {subj['subject_id']}:")
        qa_report.append(f"  # sessions: {subj['num_sessions']}")
        qa_report.append(f"  Total duration: {subj['total_duration_sec']/60:.1f} minutes")
        qa_report.append(f"  Total samples: {subj['total_samples']:,}")
        qa_report.append(f"  Missing rate: {subj['overall_missing_rate']*100:.2f}%")
        qa_report.append(f"  Placements: {subj['placements']}")
        qa_report.append("")

qa_report.append("="*70)
qa_report.append("End of report")
qa_report.append("="*70)

# Save QA report
qa_report_file = raw_dir / "QA_REPORT.txt"
with open(qa_report_file, "w", encoding="utf-8") as f:
    f.write("\n".join(qa_report))

print(f"✓ Saved quality report: {qa_report_file}")

# Also print to console
print("\n" + "\n".join(qa_report))

# ========== 8. Generate summary JSON ==========
summary = {
    "generated_at_utc": datetime.now(timezone.utc).isoformat(),
    "num_subjects": int(len(subject_agg)) if not subject_agg.empty else 0,
    "num_sessions_total": len(df_sessions) if not df_sessions.empty else 0,
    "num_sessions_keep": int(df_sessions['keep'].sum()) if not df_sessions.empty else 0,
    "total_duration_hours": safe_float(subject_agg['total_duration_sec'].sum() / 3600) if not subject_agg.empty else 0.0,
    "missing_threshold": MISSING_THRESHOLD,
    "gap_threshold_sec": GAP_THRESHOLD,
    "gap_ratio_threshold": GAP_RATIO_THRESHOLD,
    "avg_missing_rate": safe_float(df_sessions['missing_rate'].mean()) if not df_sessions.empty else 0.0,
    "avg_gap_ratio": safe_float(df_sessions['gap_ratio'].mean()) if not df_sessions.empty and 'gap_ratio' in df_sessions.columns else 0.0,
    "max_single_gap_seconds": safe_float(df_sessions['max_gap_seconds'].max()) if not df_sessions.empty and 'max_gap_seconds' in df_sessions.columns else 0.0,
    "num_activities": int(len(activity_total)) if not df_activities.empty else 0,
    "placements": sorted(df_sessions[df_sessions['keep']]['placement'].unique().tolist()) if not df_sessions.empty and df_sessions['keep'].any() else [],
}

summary_file = raw_dir / "qa_summary.json"
with open(summary_file, "w") as f:
    json.dump(summary, f, indent=2)

print(f"\n✓ Saved summary: {summary_file}")

print("\n" + "="*60)
print("Step 3 complete - Metadata & Quality Audit (top-conf/journal grade)")
print("="*60)
print(f"Output files:")
if meta_subjects_file:
    print(f"  - {meta_subjects_file}")
if meta_sessions_file:
    print(f"  - {meta_sessions_file}")
if keep_sessions_file:
    print(f"  - {keep_sessions_file}")
print(f"  - {qa_report_file}")
print(f"  - {summary_file}")
if (raw_dir / "EMPTY_SEGMENTS_TODO.csv").exists():
    print(f"  - {raw_dir / 'EMPTY_SEGMENTS_TODO.csv'} (file-level empty-window list)")
if (raw_dir / "QA_ISSUES.log").exists():
    print(f"  - {raw_dir / 'QA_ISSUES.log'} (quality issue details)")
print("="*60)

Step 3: Metadata & Quality Audit
Loading file index: data/lara/mbientlab/raw/file_index.parquet

1. Parse sensor data metadata
Sensor files: 193
Label files: 193
Extracting time spans and gap statistics (chunked)...
Extracted time spans for 193 sessions

2. Parse labels & activity statistics
Parsed 193 sessions

Checking orphan sessions...

3. Merge session metadata
Merged time span info

4. Data quality checks & empty-window cleanup
✓ QC result: keep 193 sessions, reject 0 sessions
✓ Saved: data/lara/mbientlab/raw/qa_keep_sessions.csv

Generating file-level empty-window list...

5. Generate subject-level metadata
✓ Saved: data/lara/mbientlab/raw/meta_subjects.csv
  Number of subjects: 8

6. Generate session-level metadata
✓ Saved: data/lara/mbientlab/raw/meta_sessions.csv
  Number of sessions: 193

7. Generate quality audit report
✓ Saved quality report: data/lara/mbientlab/raw/QA_REPORT.txt

LARa MbientLab IMU Dataset - Quality Audit Report
Generated at: 2025-11-15T12:54:20.443564+00

In [4]:
#!/usr/bin/env python3
"""
Step 4: Channel & Placement Strategy Selection (top-conf/journal grade)
Select placement, raw channels, derived channels; generate config file
"""

import pandas as pd
import numpy as np
from pathlib import Path
import yaml
import re

print("="*60)
print("Step 4: Channel & Placement Strategy Selection")
print("="*60)

# ========== Placement → Prefix allowlist (eradicate cross-placement leakage) ==========
PREFIX_ALLOWLIST = {
    "rwrist": ["RA_"],
    "lwrist": ["LA_"],
    "chest":  ["N_"],
    # Extensible: "rleg": ["RL_"], "lleg": ["LL_"]
}

REQ_SUFFIX = {
    "ax": "AccelerometerX", "ay": "AccelerometerY", "az": "AccelerometerZ",
    "gx": "GyroscopeX",     "gy": "GyroscopeY",     "gz": "GyroscopeZ",
}

# Coverage threshold: required column presence ratio across files (1.0=100%, 0.95=95%)
MIN_COVERAGE = 1.0

# Load metadata
raw_dir = Path("data/lara/mbientlab/raw")
configs_dir = Path("configs")
configs_dir.mkdir(parents=True, exist_ok=True)

# Load subject metadata
meta_subjects = pd.read_csv(raw_dir / "meta_subjects.csv")
print(f"\nLoaded subject metadata: {len(meta_subjects)} subjects")

# Load file index
index_file = raw_dir / "file_index.parquet"
if not index_file.exists():
    index_file = raw_dir / "file_index.csv"
file_index = pd.read_parquet(index_file) if index_file.suffix == '.parquet' else pd.read_csv(index_file)

# Keep only sensor files (more robust: filter by sensor_type and filename)
if 'sensor_type' in file_index.columns:
    sensor_files = file_index[
        (file_index['sensor_type'].isin(['acc+gyro', 'acc', 'gyro'])) &
        ~file_index['filename'].str.contains('label', case=False, na=False)
    ].copy()
else:
    sensor_files = file_index[
        ~file_index['filename'].str.contains('label', case=False, na=False)
    ].copy()

print(f"Number of sensor files: {len(sensor_files)}")

# ========== 1. Analyze placement coverage ==========
print("\n" + "="*60)
print("1. Analyze placement coverage")
print("="*60)

# Count data volume per placement
placement_stats = sensor_files.groupby('placement').agg({
    'subject_id': 'nunique',
    'session_id': 'nunique',
    'file_size_bytes': 'sum',
    'num_rows': 'sum',
}).reset_index()
placement_stats.columns = ['placement', 'num_subjects', 'num_sessions', 'total_bytes', 'total_samples']
placement_stats = placement_stats.sort_values('total_samples', ascending=False)

print("\nPlacement statistics (sorted by sample count):")
print(placement_stats.to_string(index=False))

# Fix selection to right wrist (this round)
selected_placement = "rwrist"
print(f"\nFixed placement for this round: {selected_placement}")

# Check whether placement exists
if selected_placement not in placement_stats['placement'].values:
    raise ValueError(f"Specified placement '{selected_placement}' does not exist in the data")

# Check which subjects have that placement
subjects_with_selected = sensor_files[sensor_files['placement'] == selected_placement]['subject_id'].unique()
print(f"Subjects with {selected_placement} data: {len(subjects_with_selected)}/{len(meta_subjects)}")

# ========== 2. Allowlist validation & channel check ==========
print("\n" + "="*60)
print("2. Allowlist validation & channel check")
print("="*60)

# Read only from files of selected placement
placement_files = sensor_files[sensor_files['placement'] == selected_placement]
print(f"Number of files for selected placement '{selected_placement}': {len(placement_files)}")

# Get allowlist prefixes
allowed_prefixes = PREFIX_ALLOWLIST.get(selected_placement, [])
assert allowed_prefixes, f"Prefix allowlist for '{selected_placement}' not configured; please add it in PREFIX_ALLOWLIST"
print(f"\nUsing placement→prefix allowlist: {selected_placement} → {allowed_prefixes}")

# Robust header-reading function
def read_cols(fp):
    """Read column names (with fallback)"""
    try:
        return pd.read_csv(fp, nrows=5, sep=None, engine='python').columns.tolist()
    except Exception:
        return pd.read_csv(fp, nrows=5, sep=",").columns.tolist()

# Read headers of all files
print(f"\nRead headers of all {len(placement_files)} files to check consistency...")
all_columns_by_file = []

for _, row in placement_files.iterrows():
    fp = raw_dir / row['standardized_path']
    cols = read_cols(fp)
    data_cols = [c for c in cols if not re.search(r'(time|timestamp|epoch|index|id|class|label)', c, re.I)]
    all_columns_by_file.append(data_cols)

# Assert all files were read successfully
assert len(all_columns_by_file) == len(placement_files), \
    f"{len(placement_files)-len(all_columns_by_file)} '{selected_placement}' files failed header reading; fix or exclude these files first"

print(f"✓ Successfully read {len(all_columns_by_file)} files")

# Show columns of the first file as a reference
if all_columns_by_file:
    print(f"\nData columns of the first file:")
    for col in all_columns_by_file[0]:
        print(f"  {col}")

# ========== 3. Build strict channel mapping (allowlist + consistency assertions) ==========
print("\n" + "="*60)
print("3. Build strict channel mapping (allowlist + consistency assertions)")
print("="*60)

def extract_prefix(col):
    """Extract column prefix"""
    m = re.match(r'^([A-Z]{1,}_)', col)
    return m.group(1) if m else None

def build_mapping_from_allowlist(allowed_prefixes, all_cols_by_file, min_coverage=1.0):
    """Compose column names from allowlist × suffix and check coverage"""
    mapping = {}
    missing_files = {}

    for std, suf in REQ_SUFFIX.items():
        chosen = None
        for pfx in allowed_prefixes:
            cand = f"{pfx}{suf}"
            # Count in how many files this column exists
            present_files = [i for i, cols in enumerate(all_cols_by_file) if cand in cols]
            coverage = len(present_files) / len(all_cols_by_file)

            if coverage >= min_coverage:
                chosen = cand
                if coverage < 1.0:
                    # Record indices of files missing this column (for later inspection)
                    missing_idx = [i for i in range(len(all_cols_by_file)) if i not in present_files]
                    missing_files[std] = missing_idx
                break

        if not chosen:
            raise RuntimeError(
                f"[Consistency assertion failed] {std}: Under prefixes {allowed_prefixes}, no '{suf}' meets {min_coverage*100:.0f}% coverage. "
                f"Check raw column names or change placement/prefix allowlist."
            )

        mapping[std] = chosen

    # Prefix consistency check: all mapped columns must come from allowlist
    used_prefixes = {extract_prefix(v) for v in mapping.values()}
    if not used_prefixes.issubset(set(allowed_prefixes)):
        raise RuntimeError(
            f"[Consistency assertion failed] Final mapping prefixes {used_prefixes} are not all within allowlist {allowed_prefixes}"
        )

    return mapping, used_prefixes, missing_files

# Build mapping
final_mapping, used_prefixes, missing_files = build_mapping_from_allowlist(
    allowed_prefixes, all_columns_by_file, MIN_COVERAGE
)

print("\nFinal channel mapping (standard_name <- original_column):")
for std, orig in sorted(final_mapping.items()):
    print(f"  {std} <- {orig}")

# Explicit hard assertions
assert len(used_prefixes) == 1, f"A single prefix should be used; got {used_prefixes}"
assert list(used_prefixes)[0] in set(PREFIX_ALLOWLIST[selected_placement]), \
    f"Source prefix {used_prefixes} not in allowlist {PREFIX_ALLOWLIST[selected_placement]} for {selected_placement}"

print(f"\n✓ Consistency assertions passed:")
print(f"  - Using a single prefix: {sorted(used_prefixes)}")
print(f"  - Prefix is in the allowlist: {PREFIX_ALLOWLIST[selected_placement]}")
print(f"  - Number of files checked: {len(all_columns_by_file)}")
print(f"  - Coverage requirement: {MIN_COVERAGE*100:.0f}%")

# If there are missing, print warnings
if missing_files:
    print(f"\n⚠️  The following channels are missing in some files (coverage threshold set to {MIN_COVERAGE*100:.0f}%):")
    for std, idx_list in missing_files.items():
        print(f"  {std}: missing in {len(idx_list)} files")

# ========== 4. Generate channel & placement config ==========
print("\n" + "="*60)
print("4. Generate channel & placement config")
print("="*60)

# Config content
config = {
    'dataset': 'LARa_MbientLab_IMU',
    'strategy': 'single_placement_baseline',

    # Placement configuration
    'placements': {
        'selected': [selected_placement],
        'available': placement_stats['placement'].tolist(),
        'rationale': f'Fixed selection {selected_placement}, covering {len(subjects_with_selected)} subjects',
    },

    # Raw channel configuration
    'channels': {
        'raw': ['ax', 'ay', 'az', 'gx', 'gy', 'gz'],
        'mapping': final_mapping,
        'prefix_allowlist': PREFIX_ALLOWLIST,
        'source_prefix': sorted(used_prefixes)[0],
        'min_coverage': MIN_COVERAGE,
        'description': {
            'ax': 'Accelerometer X axis (m/s² or g)',
            'ay': 'Accelerometer Y axis (m/s² or g)',
            'az': 'Accelerometer Z axis (m/s² or g)',
            'gx': 'Gyroscope X axis (rad/s or deg/s)',
            'gy': 'Gyroscope Y axis (rad/s or deg/s)',
            'gz': 'Gyroscope Z axis (rad/s or deg/s)',
        }
    },

    # Derived channel configuration
    'derived_channels': {
        'acc_mag': {
            'formula': 'sqrt(ax^2 + ay^2 + az^2)',
            'description': 'Accelerometer vector magnitude',
        },
        'gyr_mag': {
            'formula': 'sqrt(gx^2 + gy^2 + gz^2)',
            'description': 'Gyroscope vector magnitude',
        }
    },

    # Final channel order
    'final_channels': ['ax', 'ay', 'az', 'gx', 'gy', 'gz', 'acc_mag', 'gyr_mag'],

    # Multi-placement fusion (reserved; currently disabled)
    'multi_placement_fusion': {
        'enabled': False,
        'strategy': None,
        'warning': 'If enabling multi-placement fusion, you must select the fusion strategy independently within each training fold to avoid cross-fold leakage',
    },

    # Rigor notes
    'notes': [
        'Single-placement baseline: avoid cross-placement information leakage',
        'Channel mapping uses "placement→prefix allowlist + consistency assertions"; no cross-prefix voting',
        f'Consistency checked over all {len(all_columns_by_file)} {selected_placement} files',
        f'Coverage requirement: {MIN_COVERAGE*100:.0f}% (tunable tolerance)',
        'Derived channels are computed at feature-extraction stage to preserve raw data integrity',
        'Any multi-placement fusion must be chosen & validated within each LOSO fold',
    ]
}

# Save config
config_file = configs_dir / "channels.yaml"
with open(config_file, 'w', encoding='utf-8') as f:
    yaml.dump(config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

print(f"✓ Saved config: {config_file}")

# ========== 5. Validate config (random multi-file sampling) ==========
print("\n" + "="*60)
print("5. Validate config")
print("="*60)

# Verify coverage of selected placement across all sensor files
files_with_placement = sensor_files[sensor_files['placement'] == selected_placement]

print(f"\nValidate placement '{selected_placement}':")
print(f"  Files: {len(files_with_placement)}")
print(f"  Subjects: {files_with_placement['subject_id'].nunique()}")
print(f"  Sessions: {files_with_placement['session_id'].nunique()}")

# Validate channel mapping: randomly sample multiple files
verify_sample_size = min(5, len(files_with_placement))
verify_df = files_with_placement.sample(n=verify_sample_size, random_state=0)

print(f"\nValidate channel mapping (random sample of {verify_sample_size} files):")
for idx, sample_file in verify_df.iterrows():
    sample_path = raw_dir / sample_file['standardized_path']
    try:
        df_verify = pd.read_csv(sample_path, nrows=100, sep=None, engine='python')

        print(f"\nFile: {sample_file['filename']}")
        all_found = True
        for std_name in ['ax', 'ay', 'az', 'gx', 'gy', 'gz']:
            if std_name in final_mapping:
                orig_name = final_mapping[std_name]
                if orig_name in df_verify.columns:
                    sample_val = df_verify[orig_name].iloc[0]
                    print(f"  ✓ {std_name} <- {orig_name} (sample value: {sample_val:.4f})")
                else:
                    print(f"  ✗ {std_name} <- {orig_name} (column not found)")
                    all_found = False
            else:
                print(f"  ✗ {std_name} (not mapped)")
                all_found = False

        if not all_found:
            print(f"  ⚠️  This file failed validation")

    except Exception as e:
        print(f"\nFile: {sample_file['filename']}")
        print(f"  ✗ Error during validation: {e}")

# Compute derived-channel examples on the first successfully validated file
for idx, sample_file in verify_df.iterrows():
    sample_path = raw_dir / sample_file['standardized_path']
    try:
        df_verify = pd.read_csv(sample_path, nrows=100, sep=None, engine='python')
        if all(final_mapping[ch] in df_verify.columns for ch in ['ax', 'ay', 'az', 'gx', 'gy', 'gz']):
            acc_mag = np.sqrt(
                df_verify[final_mapping['ax']].values**2 +
                df_verify[final_mapping['ay']].values**2 +
                df_verify[final_mapping['az']].values**2
            )
            gyr_mag = np.sqrt(
                df_verify[final_mapping['gx']].values**2 +
                df_verify[final_mapping['gy']].values**2 +
                df_verify[final_mapping['gz']].values**2
            )

            print(f"\nDerived-channel example values (file: {sample_file['filename']}):")
            print(f"  acc_mag: min={acc_mag.min():.4f}, max={acc_mag.max():.4f}, mean={acc_mag.mean():.4f}")
            print(f"  gyr_mag: min={gyr_mag.min():.4f}, max={gyr_mag.max():.4f}, mean={gyr_mag.mean():.4f}")
            break
    except:
        continue

# ========== 6. Fuse check (reload config for verification) ==========
print("\n" + "="*60)
print("6. Fuse check (reload config for verification)")
print("="*60)

with open(config_file, "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# Extract prefixes of all mapped columns
srcs = list(cfg["channels"]["mapping"].values())
pfxs = {re.match(r'^([A-Za-z]+_)', s).group(1) for s in srcs if re.match(r'^([A-Za-z]+_)', s)}

# Assertion: all channels use the same prefix
assert len(pfxs) == 1, f"ax..gz not using a single prefix: {pfxs}"

# Assertion: prefix in allowlist
sel = cfg["placements"]["selected"][0]
allow = set(cfg["channels"]["prefix_allowlist"][sel])
assert list(pfxs)[0] in allow, f"Prefix {pfxs} not in {sel} allowlist {allow}"

print(f"✓ Config fuse check passed:")
print(f"  - Reloaded config: {config_file}")
print(f"  - All channels use a single prefix: {pfxs}")
print(f"  - Prefix is in {sel} allowlist: {allow}")

# ========== 7. Summary ==========
print("\n" + "="*60)
print("Step 4 complete - Channels & Placement Strategy")
print("="*60)
print(f"\nConfig summary:")
print(f"  Strategy: single-placement baseline")
print(f"  Fixed placement: {config['placements']['selected']}")
print(f"  Raw channels: {config['channels']['raw']}")
print(f"  Derived channels: {list(config['derived_channels'].keys())}")
print(f"  Final number of channels: {len(config['final_channels'])}")
print(f"  Prefix used: {sorted(used_prefixes)}")
print(f"  Coverage requirement: {MIN_COVERAGE*100:.0f}%")
print(f"\nConfig file: {config_file}")
print(f"\nRigor guarantees:")
print(f"  1. ✓ Use placement→prefix allowlist (hard-coded)")
print(f"  2. ✓ Consistency assertions across all files ({len(all_columns_by_file)} files)")
print(f"  3. ✓ No cross-prefix voting; avoid mis-selection")
print(f"  4. ✓ Error out if column names don't match allowlist")
print(f"  5. ✓ Explicit assertions: single prefix + within allowlist")
print(f"  6. ✓ Abort if header reading fails")
print(f"  7. ✓ Randomly sample {verify_sample_size} files to validate mapping")
print(f"  8. ✓ Fuse check: reload config and verify prefix")
print("="*60)

Step 4: Channel & Placement Strategy Selection

Loaded subject metadata: 8 subjects
Number of sensor files: 193

1. Analyze placement coverage

Placement statistics (sorted by sample count):
placement  num_subjects  num_sessions  total_bytes  total_samples
   rwrist             8            14    685467785        1120045
    chest             7            14    595662725         972496
   lwrist             6             2     80626579         131911

Fixed placement for this round: rwrist
Subjects with rwrist data: 8/8

2. Allowlist validation & channel check
Number of files for selected placement 'rwrist': 96

Using placement→prefix allowlist: rwrist → ['RA_']

Read headers of all 96 files to check consistency...
✓ Successfully read 96 files

Data columns of the first file:
  LA_AccelerometerX
  LA_AccelerometerY
  LA_AccelerometerZ
  LA_GyroscopeX
  LA_GyroscopeY
  LA_GyroscopeZ
  LL_AccelerometerX
  LL_AccelerometerY
  LL_AccelerometerZ
  LL_GyroscopeX
  LL_GyroscopeY
  LL_Gyroscop

In [1]:
import os

"""
Step 5: Timeline Unification & Resampling (top-conf/journal grade - flawless)
Unify to 50 Hz; linear interpolation/forward-fill; align start/end
"""

import pandas as pd
import numpy as np
from pathlib import Path
import yaml
import re
import json

# ========== Config ==========
TARGET_FREQ_HZ = 50.0           # Target sampling rate
MAX_INTERP_GAP_MS = 20.0        # Maximum interpolation gap (milliseconds)
MAX_INTERP_RATIO = 0.15         # Gap coverage threshold 15% (constant; applied globally)

print("="*60)
print("Step 5: Timeline Unification & Resampling")
print("="*60)

# Load config and metadata
raw_dir = Path("data/lara/mbientlab/raw")
proc_dir = Path("data/lara/mbientlab/proc")
proc_dir.mkdir(parents=True, exist_ok=True)

configs_dir = Path("configs")
with open(configs_dir / "channels.yaml", 'r', encoding='utf-8') as f:
    channel_config = yaml.safe_load(f)

selected_placement = channel_config['placements']['selected'][0]
channel_mapping = channel_config['channels']['mapping']
print(f"\nTarget sampling rate: {TARGET_FREQ_HZ} Hz")
print(f"Selected placement: {selected_placement}")

# Load QC results (all kept sessions)
qa_keep = pd.read_csv(raw_dir / "qa_keep_sessions.csv")
keep_sessions = qa_keep[qa_keep['keep'] == True].copy()
keep_sessions = keep_sessions[keep_sessions['placement'] == selected_placement].copy()

# Global processing (no per-fold dependency)
print(f"\nGlobal resampling over all kept sessions (no per-fold markers)")
print(f"  Total sessions: {len(keep_sessions)}")

# Prune switch: always ON (remove sessions with excessive gaps globally)
APPLY_PRUNE = True
keep_sessions['is_train'] = False  # Kept for compatibility with stats / logs

# Load file index
index_file = raw_dir / "file_index.parquet"
if not index_file.exists():
    index_file = raw_dir / "file_index.csv"
file_index = pd.read_parquet(index_file) if index_file.suffix == '.parquet' else pd.read_csv(index_file)

# ========== Helper functions ==========
def detect_time_column(df):
    """Detect time column (avoid false positive matches on 'ts' substring)"""
    time_cols = [c for c in df.columns
                 if re.search(r'(^|_)(time|timestamp|epoch|ts)($|_)', c, re.I)]
    return time_cols[0] if time_cols else None

def parse_time_to_seconds(time_series):
    """Convert time to seconds (correctly infer Unix timestamp units)"""
    numeric = pd.to_numeric(time_series, errors='coerce')
    if numeric.notna().sum() > len(time_series) * 0.9:
        vals = numeric.dropna().values
        max_val = np.abs(vals[:1000]).max() if len(vals) else 0

        # Infer by 2025 Unix timestamp magnitude
        if max_val > 1e17:      # nanoseconds
            return numeric * 1e-9
        elif max_val > 1e14:    # microseconds
            return numeric * 1e-6
        elif max_val > 1e11:    # milliseconds
            return numeric * 1e-3
        else:                   # seconds
            return numeric

    dt = pd.to_datetime(time_series, utc=True, errors='coerce')
    if dt.notna().sum() > len(time_series) * 0.9:
        epoch = pd.Timestamp("1970-01-01", tz='UTC')
        return (dt - epoch).dt.total_seconds()

    return None

def resample_sensor_data(df, time_col, data_cols, target_freq_hz=50.0, max_gap_ms=20.0):
    """Resample sensor data (return cleaned time for labels)"""
    time_sec = parse_time_to_seconds(df[time_col])
    if time_sec is None:
        raise ValueError("Unable to parse time column")

    valid_mask = time_sec.notna() & df[data_cols].notna().all(axis=1)
    time_clean = time_sec[valid_mask].values
    data_clean = df.loc[valid_mask, data_cols].values

    if len(time_clean) < 2:
        return None, 0.0, 0, 0.0, 0.0, None

    # De-duplicate + sort
    unique_idx = np.unique(time_clean, return_index=True)[1]
    time_clean = time_clean[unique_idx]
    data_clean = data_clean[unique_idx]

    order = np.argsort(time_clean)
    time_clean = time_clean[order]
    data_clean = data_clean[order]

    # Original frequency
    dt_orig = np.median(np.diff(time_clean))
    orig_freq_hz = 1.0 / dt_orig if dt_orig > 0 else 0.0

    # Build target timeline with integer number of samples
    dt = 1.0 / target_freq_hz
    t_start = time_clean[0]
    t_end = time_clean[-1]
    n_samples = int(np.round((t_end - t_start) / dt))
    target_time = t_start + np.arange(n_samples + 1) * dt

    # Linear interpolation
    resampled_data = np.zeros((len(target_time), len(data_cols)))
    for i in range(len(data_cols)):
        resampled_data[:, i] = np.interp(target_time, time_clean, data_clean[:, i])

    # Large-gap detection (account for jitter)
    max_gap_sec = max(max_gap_ms / 1000.0, 1.25 * dt)
    time_diffs = np.diff(time_clean)
    gap_mask = time_diffs > max_gap_sec

    is_in_gap = np.zeros(len(target_time), dtype=int)
    is_forced_nan = np.zeros(len(target_time), dtype=int)
    actual_interp_count = 0
    total_gap_time = 0.0

    if gap_mask.any():
        for i in range(len(time_clean) - 1):
            if gap_mask[i]:
                t_gap_start = time_clean[i]
                t_gap_end = time_clean[i + 1]
                gap_duration = t_gap_end - t_gap_start
                total_gap_time += gap_duration

                idxs = np.where((target_time > t_gap_start) & (target_time < t_gap_end))[0]

                if idxs.size > 0:
                    is_in_gap[idxs] = 1
                    actual_interp_count += 1

                    if idxs.size > 1:
                        forced_nan_idxs = idxs[1:]
                        is_forced_nan[forced_nan_idxs] = 1
                        resampled_data[forced_nan_idxs, :] = np.nan

    # Gap coverage
    gap_points = int(is_in_gap.sum())
    interp_ratio = gap_points / len(target_time) if len(target_time) > 0 else 0.0

    # Gap time fraction
    total_duration = t_end - t_start
    gap_time_fraction = total_gap_time / total_duration if total_duration > 0 else 0.0

    resampled_df = pd.DataFrame(resampled_data, columns=data_cols)
    resampled_df.insert(0, 'time_sec', target_time)
    resampled_df['is_in_gap'] = is_in_gap
    resampled_df['is_forced_nan'] = is_forced_nan

    return resampled_df, interp_ratio, gap_points, gap_time_fraction, orig_freq_hz, time_clean

def resample_labels(df_label, df_sensor_time_clean, label_col, target_time, label_time_col=None):
    """Resample labels (boundary NaN + sorting)"""
    if label_time_col is not None:
        time_sec = parse_time_to_seconds(df_label[label_time_col])
        if time_sec is None:
            raise ValueError("Unable to parse label time column")

        valid_mask = time_sec.notna() & df_label[label_col].notna()
        time_clean = time_sec[valid_mask].values
        labels_clean = df_label.loc[valid_mask, label_col].values
    else:
        # Use cleaned sensor time as reference
        sensor_time_original = df_sensor_time_clean
        if sensor_time_original is None:
            raise ValueError("Labels have no time column and no sensor time provided")

        min_len = min(len(df_label), len(sensor_time_original))
        if abs(len(df_label) - len(sensor_time_original)) > min_len * 0.01:
            raise ValueError(
                f"Label rows ({len(df_label)}) differ too much from sensor rows ({len(sensor_time_original)})"
            )

        time_clean = sensor_time_original[:min_len]
        labels_clean = df_label[label_col].iloc[:min_len].values

        valid_mask = pd.notna(labels_clean)
        time_clean = time_clean[valid_mask]
        labels_clean = labels_clean[valid_mask]

    if len(time_clean) == 0:
        return np.full(len(target_time), np.nan)

    # Explicit sorting
    order = np.argsort(time_clean)
    time_clean = time_clean[order]
    labels_clean = labels_clean[order]

    idx = np.searchsorted(time_clean, target_time, side='right') - 1
    idx = np.clip(idx, 0, len(time_clean) - 1)

    labels = labels_clean[idx].copy()

    # Fix: cast integers to float to allow NaN
    if labels.dtype.kind in ['i', 'u']:  # integer or unsigned integer
        labels = labels.astype('float64')

    # Boundary NaNs
    mask_before = target_time < time_clean[0]
    mask_after = target_time > time_clean[-1]
    labels[mask_before | mask_after] = np.nan

    return labels

# ========== 1. Process all sessions ==========
print("\n" + "="*60)
print("1. Resampling")
print("="*60)

resampled_records = []
interp_stats = []
issues = []

for idx, session in keep_sessions.iterrows():
    subject_id = session['subject_id']
    session_id = session['session_id']
    placement = session['placement']
    is_train = session['is_train']

    print(f"\nProcessing {subject_id}/{session_id}/{placement} {'[TRAIN]' if is_train else '[TEST]'}...")

    sensor_file = file_index[
        (file_index['subject_id'] == subject_id) &
        (file_index['session_id'] == session_id) &
        (file_index['placement'] == placement) &
        (~file_index['filename'].str.contains('label', case=False, na=False))
    ]

    label_file = file_index[
        (file_index['subject_id'] == subject_id) &
        (file_index['session_id'] == session_id) &
        (file_index['placement'] == placement) &
        (file_index['filename'].str.contains('label', case=False, na=False))
    ]

    if sensor_file.empty or label_file.empty:
        print(f"  Skip: missing files")
        continue

    sensor_path = raw_dir / sensor_file.iloc[0]['standardized_path']
    label_path = raw_dir / label_file.iloc[0]['standardized_path']

    try:
        df_sensor = pd.read_csv(sensor_path, sep=None, engine='python')
        time_col = detect_time_column(df_sensor)
        if not time_col:
            print(f"  Skip: no time column")
            continue

        data_cols = [channel_mapping[std] for std in ['ax', 'ay', 'az', 'gx', 'gy', 'gz']]
        missing_cols = [c for c in data_cols if c not in df_sensor.columns]
        if missing_cols:
            print(f"  Skip: missing columns {missing_cols}")
            continue

        print(f"  Resampling sensors ({len(df_sensor)} rows)...")
        result = resample_sensor_data(
            df_sensor, time_col, data_cols, TARGET_FREQ_HZ, MAX_INTERP_GAP_MS
        )

        if result[0] is None:
            print(f"  Skip: resampling failed")
            continue

        # Receive cleaned time for labels
        resampled_sensor, interp_ratio, gap_points, gap_time_frac, orig_freq, sensor_time_clean = result

        valid_samples = resampled_sensor[data_cols].notna().all(axis=1).sum()
        nan_samples = len(resampled_sensor) - valid_samples
        forced_nan_points = int(resampled_sensor['is_forced_nan'].sum())

        print(f"  → {len(resampled_sensor)} rows, gap coverage: {interp_ratio*100:.2f}%, NaN: {nan_samples}")

        # Prune based on global switch
        if interp_ratio > MAX_INTERP_RATIO:
            msg = f"Gap coverage too high ({interp_ratio*100:.1f}%)"
            print(f"  ⚠️  {msg}")
            issues.append({
                'subject_id': subject_id,
                'session_id': session_id,
                'placement': placement,
                'is_train': is_train,
                'issue': 'high_gap_coverage',
                'gap_coverage': round(interp_ratio, 4),
            })
            if APPLY_PRUNE:
                continue

        interp_stats.append({
            'subject_id': subject_id,
            'session_id': session_id,
            'placement': placement,
            'is_train': is_train,
            'original_samples': len(df_sensor),
            'original_freq_hz': round(orig_freq, 2),
            'resampled_samples': len(resampled_sensor),
            'valid_samples': valid_samples,
            'nan_samples': nan_samples,
            'gap_points': gap_points,
            'gap_coverage': round(interp_ratio, 4),
            'gap_time_fraction': round(gap_time_frac, 4),
            'forced_nan_points': forced_nan_points,
        })

        df_label = pd.read_csv(label_path, sep=None, engine='python')

        label_col = None
        for col_candidate in ['Class', 'class', 'label', 'Label', 'activity', 'Activity']:
            if col_candidate in df_label.columns:
                label_col = col_candidate
                break

        if not label_col:
            for col in df_label.columns:
                if any(kw in col.lower() for kw in ['label', 'activity', 'class', 'action']):
                    label_col = col
                    break

        if not label_col:
            print(f"  Skip: no label column")
            issues.append({
                'subject_id': subject_id,
                'session_id': session_id,
                'placement': placement,
                'is_train': is_train,
                'issue': 'no_label_column',
            })
            continue

        label_time_col = detect_time_column(df_label)
        target_time = resampled_sensor['time_sec'].values

        print(f"  Resampling labels...")
        try:
            if label_time_col:
                resampled_labels = resample_labels(
                    df_label, sensor_time_clean, label_col, target_time,
                    label_time_col=label_time_col
                )
            else:
                resampled_labels = resample_labels(
                    df_label, sensor_time_clean, label_col, target_time
                )

            resampled_sensor['label'] = resampled_labels

        except Exception as e:
            print(f"  Skip: label resampling failed - {e}")
            issues.append({
                'subject_id': subject_id,
                'session_id': session_id,
                'placement': placement,
                'is_train': is_train,
                'issue': 'label_resample_error',
                'error': str(e),  # include error details
            })
            continue

        resampled_sensor.rename(columns={
            channel_mapping['ax']: 'ax',
            channel_mapping['ay']: 'ay',
            channel_mapping['az']: 'az',
            channel_mapping['gx']: 'gx',
            channel_mapping['gy']: 'gy',
            channel_mapping['gz']: 'gz',
        }, inplace=True)

        resampled_sensor.insert(0, 'subject_id', subject_id)
        resampled_sensor.insert(1, 'session_id', session_id)
        resampled_sensor.insert(2, 'placement', placement)

        resampled_records.append(resampled_sensor)
        print(f"  ✓ Done")

    except Exception as e:
        print(f"  ✗ Error: {e}")
        issues.append({
            'subject_id': subject_id,
            'session_id': session_id,
            'placement': placement,
            'is_train': is_train,
            'issue': 'processing_error',
            'error': str(e),  # include error details
        })

print(f"\nSuccessfully processed: {len(resampled_records)} sessions")
print(f"Skipped/failed: {len(issues)} sessions")

# ========== 2. Combine & save ==========
print("\n" + "="*60)
print("2. Combine & Save")
print("="*60)

if resampled_records:
    df_all = pd.concat(resampled_records, ignore_index=True)

    # Optimization: cast dtypes (reduce size)
    for c in ['ax', 'ay', 'az', 'gx', 'gy', 'gz']:
        df_all[c] = df_all[c].astype('float32')
    df_all['time_sec'] = df_all['time_sec'].astype('float64')  # Keep high precision for time

    output_file = proc_dir / "resampled.parquet"

    if output_file.exists():
        import shutil
        if output_file.is_dir():
            shutil.rmtree(output_file)
        else:
            output_file.unlink()
        print(f"Removed old data: {output_file}")

    df_all.to_parquet(
        output_file,
        index=False,
        partition_cols=['subject_id', 'placement'],
        engine='pyarrow'
    )
    print(f"✓ Saved: {output_file}")
    print(f"  Total rows: {len(df_all):,}")
    print(f"  # subjects: {df_all['subject_id'].nunique()}")
    print(f"  # sessions: {df_all.groupby(['subject_id', 'session_id']).ngroups}")

    valid_mask = df_all[['ax', 'ay', 'az', 'gx', 'gy', 'gz']].notna().all(axis=1)
    print(f"  Valid samples: {valid_mask.sum():,} ({valid_mask.sum()/len(df_all)*100:.1f}%)")
    print(f"  Samples with NaN: {(~valid_mask).sum():,}")

    print("\nData preview:")
    print(df_all.head(10).to_string())

    print("\nNumeric column stats (valid samples):")
    numeric_cols = ['ax', 'ay', 'az', 'gx', 'gy', 'gz']
    print(df_all.loc[valid_mask, numeric_cols].describe().round(4))
else:
    print("Warning: No data to save")

# ========== 3. Save statistics ==========
if interp_stats:
    df_interp = pd.DataFrame(interp_stats)
    interp_file = proc_dir / "resample_stats.csv"
    df_interp.to_csv(interp_file, index=False)
    print(f"\n✓ Saved stats: {interp_file}")

    if 'is_train' in df_interp.columns and df_interp['is_train'].any():
        train_stats = df_interp[df_interp['is_train']]
        print(f"\nGap statistics (train fold):")
        print(f"  Mean gap coverage: {train_stats['gap_coverage'].mean()*100:.2f}%")
        print(f"  Max gap coverage: {train_stats['gap_coverage'].max()*100:.2f}%")
        print(f"  Mean gap time fraction: {train_stats['gap_time_fraction'].mean()*100:.2f}%")

        print(f"\nGap statistics (overall):")
        print(f"  Mean gap coverage: {df_interp['gap_coverage'].mean()*100:.2f}%")
        print(f"  Max gap coverage: {df_interp['gap_coverage'].max()*100:.2f}%")
    else:
        print(f"\nGap statistics:")
        print(f"  Mean gap coverage: {df_interp['gap_coverage'].mean()*100:.2f}%")
        print(f"  Max gap coverage: {df_interp['gap_coverage'].max()*100:.2f}%")

if issues:
    df_issues = pd.DataFrame(issues)
    issues_file = proc_dir / "resample_issues.csv"
    df_issues.to_csv(issues_file, index=False)
    print(f"\n⚠️  Saved issue records: {issues_file} ({len(issues)} items)")

print("\n" + "="*60)
print("Step 5 complete - Flawless version")
print("="*60)
print(f"\nFinal fixes:")
print(f"  1. ✓ Prune switch (always ON; global high-gap sessions removed)")
print(f"  2. ✓ Label time harmonized (reuse cleaned time)")
print(f"  3. ✓ Complete error information (\"error\" field)")
print(f"  4. ✓ Comment fix (constant threshold 0.15)")
print(f"  5. ✓ Type optimization (float32/float64)")
print("="*60)

Step 5: Timeline Unification & Resampling

Target sampling rate: 50.0 Hz
Selected placement: rwrist

Global resampling over all kept sessions (no per-fold markers)
  Total sessions: 96

1. Resampling

Processing S07/R03/rwrist [TEST]...
  Resampling sensors (11758 rows)...
  → 5879 rows, gap coverage: 0.00%, NaN: 0
  Resampling labels...
  ✓ Done

Processing S07/R05/rwrist [TEST]...
  Resampling sensors (11766 rows)...
  → 5883 rows, gap coverage: 0.00%, NaN: 0
  Resampling labels...
  ✓ Done

Processing S07/R06/rwrist [TEST]...
  Resampling sensors (11838 rows)...
  → 5919 rows, gap coverage: 0.00%, NaN: 0
  Resampling labels...
  ✓ Done

Processing S07/R07/rwrist [TEST]...
  Resampling sensors (11795 rows)...
  → 5898 rows, gap coverage: 0.00%, NaN: 0
  Resampling labels...
  ✓ Done

Processing S07/R08/rwrist [TEST]...
  Resampling sensors (11804 rows)...
  → 5902 rows, gap coverage: 0.00%, NaN: 0
  Resampling labels...
  ✓ Done

Processing S07/R09/rwrist [TEST]...
  Resampling senso

In [2]:
import os

"""
Step 6: Sensor Preprocessing (top-conf/journal grade - final fixed version)
Accelerometer high-pass to remove gravity; gyroscope denoising; adaptive ±Nσ clipping (target 1%)
Global version: no FOLD_ID required; thresholds estimated on all data.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import yaml
import json
from scipy import signal

# ========== Config ==========
# Accelerometer high-pass (remove gravity)
ACC_HPF_CUTOFF_HZ = 0.3      # Cutoff frequency
ACC_HPF_ORDER = 2            # Filter order

# Gyroscope low-pass (denoise)
GYR_LPF_CUTOFF_HZ = 20.0     # Cutoff frequency
GYR_LPF_ORDER = 2            # Filter order

# Adaptive clipping threshold (auto-tuned to target clipping rate)
TARGET_CLIP_RATE = 0.01      # Target clipping rate 1% (sum of both tails)

# Sampling rate (from Step 5)
SAMPLING_RATE_HZ = 50.0

# Unit conversions
DEG2RAD = np.pi / 180.0
G_TO_MS2 = 9.80665

print("="*60)
print("Step 6: Sensor Preprocessing")
print("="*60)

# Load data
proc_dir = Path("data/lara/mbientlab/proc")
configs_dir = Path("configs")

print(f"\nLoading resampled data: {proc_dir / 'resampled.parquet'}")
df = pd.read_parquet(proc_dir / "resampled.parquet")

print(f"Data shape: {df.shape}")
print(f"Number of subjects: {df['subject_id'].nunique()}")
print(f"Number of sessions: {df.groupby(['subject_id', 'session_id'], observed=True).ngroups}")

# ========== 0. Unit normalization ==========
print("\n" + "="*60)
print("0. Unit normalization")
print("="*60)

acc_channels = ['ax', 'ay', 'az']
print(f"\nAccelerometer unit conversion: g → m/s²")
for ch in acc_channels:
    if ch in df.columns:
        mask = df[ch].notna()
        df.loc[mask, ch] = df.loc[mask, ch] * G_TO_MS2
print(f"✓ Conversion factor: {G_TO_MS2:.5f}")

gyr_channels = ['gx', 'gy', 'gz']
print(f"\nGyroscope unit conversion: deg/s → rad/s")
for ch in gyr_channels:
    if ch in df.columns:
        mask = df[ch].notna()
        df.loc[mask, ch] = df.loc[mask, ch] * DEG2RAD
print(f"✓ Conversion factor: π/180 = {DEG2RAD:.6f}")

# ========== Helper functions ==========
def design_highpass_filter(cutoff_hz, fs_hz, order=2):
    """Design a high-pass Butterworth filter"""
    nyq = 0.5 * fs_hz
    normal_cutoff = cutoff_hz / nyq
    b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
    return b, a

def design_lowpass_filter(cutoff_hz, fs_hz, order=2):
    """Design a low-pass Butterworth filter"""
    nyq = 0.5 * fs_hz
    normal_cutoff = cutoff_hz / nyq
    b, a = signal.butter(order, normal_cutoff, btype='low', analog=False)
    return b, a

def filtfilt_nan_safe(x, b, a):
    """Zero-phase filtering tolerant to NaN (filter each contiguous non-NaN run)"""
    y = x.copy()
    good = np.isfinite(x)

    if not good.any():
        return x

    idx = np.where(good)[0]
    cuts = np.where(np.diff(idx) > 1)[0] + 1
    runs = np.split(idx, cuts)

    padlen = 3 * (max(len(a), len(b)) - 1)

    for run in runs:
        seg = x[run]

        if len(seg) > padlen:
            y[run] = signal.filtfilt(b, a, seg, method="pad")
        else:
            tmp = signal.lfilter(b, a, seg)
            y[run] = signal.lfilter(b, a, tmp[::-1])[::-1]

    return y

def apply_filter_by_session(df, channels, b, a):
    """Apply zero-phase filtering grouped by session (include placement grouping + sorting)"""
    filtered_data = []

    for (subj, sess, plc), group in df.groupby(['subject_id', 'session_id', 'placement'], observed=True):
        group = group.sort_values('time_sec').copy()

        for ch in channels:
            if ch not in group.columns:
                continue

            data = group[ch].values
            filtered = filtfilt_nan_safe(data, b, a)
            group[ch] = filtered

        filtered_data.append(group)

    return pd.concat(filtered_data, ignore_index=True)

def compute_clip_thresholds_target(df, channels, target_rate=0.01, use_robust=True):
    """Adaptive thresholds to a target clipping rate (Scheme A)

    Args:
        target_rate: target total clipping rate for both tails (e.g., 0.01 = 1%)
        use_robust: if True, use Median±k·(1.4826·MAD); otherwise Mean±k·Std
    """
    eps = 1e-6
    thresholds = {}

    for ch in channels:
        if ch not in df.columns:
            continue

        x = df[ch].dropna().values
        if x.size == 0:
            continue

        if use_robust:
            # Robust estimate: Median ± k·(1.4826·MAD)
            median = np.median(x)
            mad = np.median(np.abs(x - median))
            robust_std = max(1.4826 * mad, eps)

            deviations = np.abs(x - median) / robust_std
            k = np.quantile(deviations, 1 - target_rate)

            lower = median - k * robust_std
            upper = median + k * robust_std

            thresholds[ch] = {
                'center': float(median),
                'scale': float(robust_std),
                'k': float(k),
                'lower': float(lower),
                'upper': float(upper),
                'method': f'Median±k·MAD (k={k:.3f}, both tails total {target_rate*100:.1f}%)',
            }
        else:
            # Conventional estimate: Mean ± k·Std
            mean = np.mean(x)
            std = max(np.std(x), eps)

            deviations = np.abs(x - mean) / std
            k = np.quantile(deviations, 1 - target_rate)

            lower = mean - k * std
            upper = mean + k * std

            thresholds[ch] = {
                'center': float(mean),
                'scale': float(std),
                'k': float(k),
                'lower': float(lower),
                'upper': float(upper),
                'method': f'Mean±k·Std (k={k:.3f}, both tails total {target_rate*100:.1f}%)',
            }

    return thresholds

def apply_clip(df, channels, thresholds):
    """Apply clipping and compute actual clipping rate"""
    df_clipped = df.copy()
    clip_stats = {}

    for ch in channels:
        if ch not in df_clipped.columns or ch not in thresholds:
            continue

        lower = thresholds[ch]['lower']
        upper = thresholds[ch]['upper']

        mask = df_clipped[ch].notna()
        total = mask.sum()

        if total > 0:
            outliers = ((df_clipped.loc[mask, ch] < lower) | (df_clipped.loc[mask, ch] > upper)).sum()
            clip_rate = outliers / total
            clip_stats[ch] = {
                'outliers': int(outliers),
                'total': int(total),
                'rate': float(clip_rate),
            }

        df_clipped.loc[mask, ch] = df_clipped.loc[mask, ch].clip(lower, upper)

    return df_clipped, clip_stats

# ========== 1. Design filters ==========
print("\n" + "="*60)
print("1. Design filters")
print("="*60)

print(f"\nAccelerometer high-pass filter:")
print(f"  Cutoff frequency: {ACC_HPF_CUTOFF_HZ} Hz")
print(f"  Order: {ACC_HPF_ORDER}")
acc_b, acc_a = design_highpass_filter(ACC_HPF_CUTOFF_HZ, SAMPLING_RATE_HZ, ACC_HPF_ORDER)

print(f"\nGyroscope low-pass filter:")
print(f"  Cutoff frequency: {GYR_LPF_CUTOFF_HZ} Hz")
print(f"  Order: {GYR_LPF_ORDER}")
gyr_b, gyr_a = design_lowpass_filter(GYR_LPF_CUTOFF_HZ, SAMPLING_RATE_HZ, GYR_LPF_ORDER)

# ========== 2. Apply filters (by session + placement) ==========
print("\n" + "="*60)
print("2. Apply filters (by session + placement, zero-phase)")
print("="*60)

print("\nApplying accelerometer high-pass (remove gravity)...")
df_filtered = apply_filter_by_session(df, acc_channels, acc_b, acc_a)
print("✓ Done")

print("\nApplying gyroscope low-pass (denoise)...")
df_filtered = apply_filter_by_session(df_filtered, gyr_channels, gyr_b, gyr_a)
print("✓ Done")

# ========== 3. Compute clipping thresholds (adaptive to target rate) ==========
print("\n" + "="*60)
print("3. Compute adaptive clipping thresholds (target clipping rate)")
print("="*60)

print("Estimate clipping thresholds on all data")
print(f"  Target clip rate: {TARGET_CLIP_RATE*100:.1f}%")

all_channels = acc_channels + gyr_channels
df_for_stats = df_filtered  # global estimation on all subjects
clip_thresholds = compute_clip_thresholds_target(
    df_for_stats, all_channels, TARGET_CLIP_RATE, use_robust=True
)

print(f"\nClipping thresholds (adaptive robust estimation):")
for ch, thresh in clip_thresholds.items():
    print(f"  {ch}:")
    print(f"    center: {thresh['center']:.4f}")
    print(f"    scale: {thresh['scale']:.4f}")
    print(f"    k: {thresh['k']:.3f}")
    print(f"    range: [{thresh['lower']:.4f}, {thresh['upper']:.4f}]")

# ========== 4. Apply clipping ==========
print("\n" + "="*60)
print("4. Apply adaptive clipping")
print("="*60)

df_clipped, clip_stats = apply_clip(df_filtered, all_channels, clip_thresholds)

print("\nActual clipping statistics:")
for ch, stats in clip_stats.items():
    print(f"  {ch}: {stats['outliers']:,} / {stats['total']:,} ({stats['rate']*100:.2f}%)")

# ========== 5. Cast to float32 to save memory ==========
print("\n" + "="*60)
print("5. Data type optimization")
print("="*60)

numeric_cols = ['ax', 'ay', 'az', 'gx', 'gy', 'gz']
for col in numeric_cols:
    if col in df_clipped.columns:
        df_clipped[col] = df_clipped[col].astype('float32')

print(f"✓ Sensor columns cast to float32")
print(f"✓ time_sec kept as float64")

# ========== 6. Save results ==========
print("\n" + "="*60)
print("6. Save results")
print("="*60)

output_file = proc_dir / "filtered.parquet"

if output_file.exists():
    import shutil
    if output_file.is_dir():
        shutil.rmtree(output_file)
    else:
        output_file.unlink()
    print(f"Removed old data: {output_file}")

df_clipped.to_parquet(
    output_file,
    index=False,
    partition_cols=['subject_id', 'placement'],
    engine='pyarrow'
)
print(f"✓ Saved: {output_file}")
print(f"  Data shape: {df_clipped.shape}")

print("\nData preview:")
print(df_clipped.head(10).to_string())

print("\nPost-filter numeric column stats:")
valid_mask = df_clipped[numeric_cols].notna().all(axis=1)
print(df_clipped.loc[valid_mask, numeric_cols].describe().round(4))

# ========== 7. Save filter configuration ==========
print("\n" + "="*60)
print("7. Save filter configuration")
print("="*60)

filter_config = {
    'sampling_rate_hz': SAMPLING_RATE_HZ,

    'units': {
        'accelerometer': 'm/s² (converted from g)',
        'gyroscope': 'rad/s (converted from deg/s)',
        'conversion': {
            'accelerometer_g_to_ms2': G_TO_MS2,
            'gyroscope_deg_to_rad': DEG2RAD,
        }
    },

    'dtypes': {
        'sensor_channels': 'float32',
        'time_sec': 'float64',
    },

    'accelerometer': {
        'filter_type': 'highpass',
        'purpose': 'detrend (remove gravity)',
        'method': 'Butterworth',
        'cutoff_hz': ACC_HPF_CUTOFF_HZ,
        'order': ACC_HPF_ORDER,
        'coefficients': {
            'b': acc_b.tolist(),
            'a': acc_a.tolist(),
        },
        'zero_phase': True,
    },

    'gyroscope': {
        'filter_type': 'lowpass',
        'purpose': 'denoise',
        'method': 'Butterworth',
        'cutoff_hz': GYR_LPF_CUTOFF_HZ,
        'order': GYR_LPF_ORDER,
        'coefficients': {
            'b': gyr_b.tolist(),
            'a': gyr_a.tolist(),
        },
        'zero_phase': True,
    },

    'clipping': {
        'method': 'Adaptive robust estimation (Median±k·MAD, Scheme A)',
        'target_clip_rate': TARGET_CLIP_RATE,
        'estimated_on': 'all_data',
        'fold_id': None,
        'thresholds': clip_thresholds,
        'actual_clip_stats': clip_stats,
        'rationale': (
            f'Auto-adjust k so the global clipping rate reaches the target '
            f'{TARGET_CLIP_RATE*100:.1f}% over all subjects'
        ),
    },

    'notes': [
        'All filters use filtfilt for zero phase',
        'Filtering is grouped by session + placement, sorted by time_sec; avoid crossing session boundaries',
        'filtfilt_nan_safe filters each contiguous non-NaN run separately',
        'Accelerometer converted from g to m/s² (×9.80665)',
        'Gyroscope converted from deg/s to rad/s (×π/180)',
        f'Adaptive clipping thresholds: determine k on all data so clipping ≈ {TARGET_CLIP_RATE*100:.1f}%, then apply consistently to all data',
        'NaNs remain unchanged',
        'Sensor columns are float32; time_sec is float64',
    ]
}

filter_config_file = configs_dir / "filter.yaml"
with open(filter_config_file, 'w', encoding='utf-8') as f:
    yaml.dump(filter_config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
print(f"✓ Saved filter configuration: {filter_config_file}")

filter_config_json = configs_dir / "filter.json"
with open(filter_config_json, 'w', encoding='utf-8') as f:
    json.dump(filter_config, f, indent=2)
print(f"✓ Saved filter configuration: {filter_config_json}")

# ========== 8. Summary ==========
print("\n" + "="*60)
print("Step 6 complete - Sensor preprocessing (global version)")
print("="*60)
print(f"\nConfig:")
print(f"  Units: Acc g→m/s², Gyro deg/s→rad/s")
print(f"  Accelerometer: high-pass {ACC_HPF_CUTOFF_HZ} Hz (remove gravity)")
print(f"  Gyroscope: low-pass {GYR_LPF_CUTOFF_HZ} Hz (denoise)")
print(f"  Clipping: adaptive ±k·MAD (target {TARGET_CLIP_RATE*100:.1f}%)")
print(f"  Clipping thresholds estimated on: all data")
print(f"\nResults:")
print(f"  Output file: {output_file}")
print(f"  Config file: {filter_config_file}")
print(f"  Data shape: {df_clipped.shape}")
print("\nFinal fixes:")
print(f"  ✓ Adaptive clipping thresholds (Scheme A)")
print(f"  ✓ Target clipping rate {TARGET_CLIP_RATE*100:.1f}%, auto-solve k (global)")
print(f"  ✓ Group by placement + sort by time_sec")
print(f"  ✓ Write actual clipping rate into config")
print("="*60)

Step 6: Sensor Preprocessing

Loading resampled data: data/lara/mbientlab/proc/resampled.parquet
Data shape: (560070, 13)
Number of subjects: 8
Number of sessions: 96

0. Unit normalization

Accelerometer unit conversion: g → m/s²
✓ Conversion factor: 9.80665

Gyroscope unit conversion: deg/s → rad/s
✓ Conversion factor: π/180 = 0.017453

1. Design filters

Accelerometer high-pass filter:
  Cutoff frequency: 0.3 Hz
  Order: 2

Gyroscope low-pass filter:
  Cutoff frequency: 20.0 Hz
  Order: 2

2. Apply filters (by session + placement, zero-phase)

Applying accelerometer high-pass (remove gravity)...
✓ Done

Applying gyroscope low-pass (denoise)...
✓ Done

3. Compute adaptive clipping thresholds (target clipping rate)
Estimate clipping thresholds on all data
  Target clip rate: 1.0%

Clipping thresholds (adaptive robust estimation):
  ax:
    center: 0.0152
    scale: 1.3530
    k: 6.712
    range: [-9.0660, 9.0964]
  ay:
    center: 0.0141
    scale: 1.3815
    k: 5.790
    range: [-7.9

In [3]:
import os

"""
Step 7: Coordinate/Magnitude Normalization (top-conf/journal grade)
Compute magnitude channels; z-score standardization (global statistics)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import json
import pickle

# ========== Config ==========
EPSILON = 1e-8  # Prevent division by zero

print("="*60)
print("Step 7: Coordinate/Magnitude Normalization")
print("="*60)

# Load data
proc_dir = Path("data/lara/mbientlab/proc")
configs_dir = Path("configs")

print(f"\nLoading filtered data: {proc_dir / 'filtered.parquet'}")
df = pd.read_parquet(proc_dir / "filtered.parquet")

print(f"Data shape: {df.shape}")
print(f"Number of subjects: {df['subject_id'].nunique()}")
print(f"Number of sessions: {df.groupby(['subject_id', 'session_id'], observed=True).ngroups}")

# ========== 1. Compute derived channels (magnitude) ==========
print("\n" + "="*60)
print("1. Compute derived channels (magnitude)")
print("="*60)

# Accelerometer magnitude
print("\nComputing acc_mag = sqrt(ax² + ay² + az²)...")
df['acc_mag'] = np.sqrt(
    df['ax'].values**2 +
    df['ay'].values**2 +
    df['az'].values**2
).astype('float32')

# Gyroscope magnitude
print("Computing gyr_mag = sqrt(gx² + gy² + gz²)...")
df['gyr_mag'] = np.sqrt(
    df['gx'].values**2 +
    df['gy'].values**2 +
    df['gz'].values**2
).astype('float32')

print(f"✓ Added derived channels: acc_mag, gyr_mag")

# Show derived-channel stats
print("\nDerived channel statistics (post-filter):")
for col in ['acc_mag', 'gyr_mag']:
    valid_data = df[col].dropna()
    if len(valid_data) > 0:
        print(f"  {col}:")
        print(f"    Mean: {valid_data.mean():.4f}")
        print(f"    Std: {valid_data.std():.4f}")
        print(f"    Range: [{valid_data.min():.4f}, {valid_data.max():.4f}]")

# ========== 2. Determine training set (global) ==========
print("\n" + "="*60)
print("2. Determine training set (global)")
print("="*60)

# In this simplified global version, we use ALL subjects to estimate statistics
df_train = df
train_subjects = set(df['subject_id'].unique())
test_subjects = set()  # no explicit test set at this step

print("Compute statistics on all data (no per-fold split)")
print(f"  Samples: {len(df):,}")
print(f"  Subjects: {len(train_subjects)}")

# ========== 3. Compute z-score parameters (global) ==========
print("\n" + "="*60)
print("3. Compute z-score parameters (global)")
print("="*60)

# Channels to standardize
channels_to_normalize = ['ax', 'ay', 'az', 'gx', 'gy', 'gz', 'acc_mag', 'gyr_mag']

# Compute mean and std (valid data only)
scaler_params = {}

print("\nz-score parameters (global):")
for ch in channels_to_normalize:
    if ch not in df_train.columns:
        continue

    valid_data = df_train[ch].dropna().values

    if len(valid_data) > 0:
        mean = float(np.mean(valid_data))
        std = float(np.std(valid_data))

        # Guard against zero std
        if std < EPSILON:
            std = 1.0

        scaler_params[ch] = {
            'mean': mean,
            'std': std,
        }

        print(f"  {ch}:")
        print(f"    Mean: {mean:.6f}")
        print(f"    Std: {std:.6f}")

# ========== 4. Apply z-score standardization ==========
print("\n" + "="*60)
print("4. Apply z-score standardization")
print("="*60)

df_normalized = df.copy()

for ch in channels_to_normalize:
    if ch not in scaler_params:
        continue

    mean = scaler_params[ch]['mean']
    std = scaler_params[ch]['std']

    # Standardize non-NaN values only; cast to float32 to avoid warnings
    mask = df_normalized[ch].notna()
    normalized_values = ((df_normalized.loc[mask, ch] - mean) / (std + EPSILON)).astype('float32')
    df_normalized.loc[mask, ch] = normalized_values

print(f"✓ Standardized {len(scaler_params)} channels")

# Show post-standardization stats (global)
print("\nPost-standardization stats (global):")
for ch in channels_to_normalize:
    if ch not in scaler_params:
        continue

    valid_data = df_normalized[ch].dropna()
    if len(valid_data) > 0:
        print(f"  {ch}:")
        print(f"    Mean: {valid_data.mean():.6f} (should be near 0)")
        print(f"    Std: {valid_data.std():.6f} (should be near 1)")

# ========== 5. Save results ==========
print("\n" + "="*60)
print("5. Save results")
print("="*60)

# Save normalized data
output_file = proc_dir / "normalized.parquet"

# Delete existing directory/file (avoid duplicate appends)
if output_file.exists():
    import shutil
    if output_file.is_dir():
        shutil.rmtree(output_file)
    else:
        output_file.unlink()
    print(f"Removed old data: {output_file}")

df_normalized.to_parquet(
    output_file,
    index=False,
    partition_cols=['subject_id', 'placement'],
    engine='pyarrow'
)
print(f"✓ Saved: {output_file}")
print(f"  Data shape: {df_normalized.shape}")

# Show data preview
print("\nData preview:")
display_cols = ['subject_id', 'session_id', 'ax', 'ay', 'az', 'gx', 'gy', 'gz', 'acc_mag', 'gyr_mag', 'label']
available_cols = [c for c in display_cols if c in df_normalized.columns]
print(df_normalized[available_cols].head(10).to_string())

# Post-standardization numeric stats (overall)
print("\nPost-standardization numeric column stats (overall):")
numeric_cols = ['ax', 'ay', 'az', 'gx', 'gy', 'gz', 'acc_mag', 'gyr_mag']
valid_mask = df_normalized[numeric_cols].notna().all(axis=1)
print(df_normalized.loc[valid_mask, numeric_cols].describe().round(4))

# ========== 6. Save scaler parameters ==========
print("\n" + "="*60)
print("6. Save scaler parameters")
print("="*60)

scaler_info = {
    'fold_id': None,             # global
    'epsilon': EPSILON,
    'train_subjects': sorted(list(train_subjects)),
    'test_subjects': None,       # not defined at this step
    'channels': channels_to_normalize,
    'params': scaler_params,
    'notes': [
        'z-score standardization: (x - mean) / (std + ε)',
        'Mean and std computed from all available samples (global statistics)',
        'If std < ε, set std = 1.0 to avoid divide-by-zero',
        'NaN values are excluded from stats and remain NaN after normalization',
    ]
}

# Save as pickle (global)
scaler_file = proc_dir / "standardization.pkl"
with open(scaler_file, 'wb') as f:
    pickle.dump(scaler_info, f)
print(f"✓ Saved scaler: {scaler_file}")

# Also save as JSON (human-readable)
scaler_json = proc_dir / "standardization.json"
with open(scaler_json, 'w') as f:
    json.dump(scaler_info, f, indent=2)
print(f"✓ Saved scaler: {scaler_json}")

# ========== 7. Validate standardization ==========
print("\n" + "="*60)
print("7. Validate standardization (global)")
print("="*60)

for ch in channels_to_normalize[:3]:  # check first 3 channels only
    if ch in scaler_params:
        valid_data = df_normalized[ch].dropna()
        if len(valid_data) > 0:
            mean_check = valid_data.mean()
            std_check = valid_data.std()
            print(f"  {ch}: mean={mean_check:.6f}, std={std_check:.6f}")

# ========== 8. Summary ==========
print("\n" + "="*60)
print("Step 7 complete - Coordinate/Magnitude Normalization (global)")
print("="*60)
print(f"\nConfig:")
print(f"  Method: z-score standardization (global)")
print(f"  ε (avoid divide-by-zero): {EPSILON}")
print(f"  Standardized channels: {len(scaler_params)}")
print(f"\nResults:")
print(f"  Output data: {output_file}")
print(f"  Scaler (pkl): {scaler_file}")
print(f"  Scaler (json): {scaler_json}")
print(f"  Data shape: {df_normalized.shape}")
print(f"  New columns: acc_mag, gyr_mag")
print("\nRigor guarantees:")
print("  1. ✓ Mean/std computed once on all data (global stats)")
print("  2. ✓ NaNs remain unchanged")
print("  3. ✓ ε={} prevents divide-by-zero".format(EPSILON))
print("  4. ✓ Derived channels acc_mag, gyr_mag")
print("="*60)

Step 7: Coordinate/Magnitude Normalization

Loading filtered data: data/lara/mbientlab/proc/filtered.parquet
Data shape: (560070, 13)
Number of subjects: 8
Number of sessions: 96

1. Compute derived channels (magnitude)

Computing acc_mag = sqrt(ax² + ay² + az²)...
Computing gyr_mag = sqrt(gx² + gy² + gz²)...
✓ Added derived channels: acc_mag, gyr_mag

Derived channel statistics (post-filter):
  acc_mag:
    Mean: 2.9151
    Std: 2.3751
    Range: [0.0073, 14.5074]
  gyr_mag:
    Mean: 1.3122
    Std: 1.1772
    Range: [0.0015, 7.0648]

2. Determine training set (global)
Compute statistics on all data (no per-fold split)
  Samples: 560,070
  Subjects: 8

3. Compute z-score parameters (global)

z-score parameters (global):
  ax:
    Mean: -0.008808
    Std: 2.274599
  ay:
    Mean: 0.006988
    Std: 2.146420
  az:
    Mean: 0.001384
    Std: 2.087749
  gx:
    Mean: -0.002287
    Std: 0.886048
  gy:
    Mean: 0.010105
    Std: 1.160388
  gz:
    Mean: 0.023344
    Std: 0.987726
  acc_ma

In [4]:
#!/usr/bin/env python3

"""
Step 8: Label Alignment & Cleaning (top-conf/journal grade - revised)
Clean NULL/transition, unify to a standard label set, and record mappings
"""

import pandas as pd
import numpy as np
from pathlib import Path
import yaml
import json
from collections import Counter

# ========== Config ==========

# Label cleaning strategy
NULL_STRATEGY = "remove"  # "remove" or "merge_to_transition"
TRANSITION_STRATEGY = "merge_to_nearest"  # "remove" or "merge_to_nearest"

# Unmapped label threshold (abort if exceeded)
UNMAPPED_THRESHOLD = 0.01  # 1%

print("="*60)
print("Step 8: Label Alignment & Cleaning")
print("="*60)

# Create directories
proc_dir = Path("data/lara/mbientlab/proc")
configs_dir = Path("configs")
reports_dir = Path("reports")
reports_dir.mkdir(parents=True, exist_ok=True)

print(f"\nLoading normalized data: {proc_dir / 'normalized.parquet'}")
df = pd.read_parquet(proc_dir / "normalized.parquet")

print(f"Data shape: {df.shape}")
print(f"Number of subjects: {df['subject_id'].nunique()}")

# ========== 1. Analyze original label distribution ==========

print("\n" + "="*60)
print("1. Analyze original label distribution")
print("="*60)

# Count all labels
label_counts = df['label'].value_counts(dropna=False)
total_samples = len(df)
null_count = df['label'].isna().sum()

print(f"\nOriginal label stats:")
print(f"  Total samples: {total_samples:,}")
print(f"  NULL samples: {null_count:,} ({null_count/total_samples*100:.2f}%)")
print(f"  Number of label classes: {df['label'].nunique(dropna=True)}")

print(f"\nLabel distribution (top 20):")
for label, count in label_counts.head(20).items():
    pct = count / total_samples * 100
    print(f"  {str(label):30s}: {count:8,} ({pct:5.2f}%)")

# ========== 2. Define label mapping rules ==========

print("\n" + "="*60)
print("2. Define label mapping rules")
print("="*60)

# Map LARa dataset labels to a cross-dataset unified label superset
# Covers LARa / RealWorld / SHL
LABEL_MAPPING = {
    # Basic activities (shared by RealWorld + LARa)
    1: {"original": "walking", "mapped": "walking", "category": "locomotion"},
    2: {"original": "running", "mapped": "running", "category": "locomotion"},
    3: {"original": "shuffling", "mapped": "walking", "category": "locomotion"},  # merge into walking
    4: {"original": "stairs (ascending)", "mapped": "upstairs", "category": "locomotion"},
    5: {"original": "stairs (descending)", "mapped": "downstairs", "category": "locomotion"},
    6: {"original": "standing", "mapped": "standing", "category": "static"},
    7: {"original": "sitting", "mapped": "sitting", "category": "static"},
    8: {"original": "lying", "mapped": "lying", "category": "static"},

    # Transport (specific to LARa; not in RealWorld)
    13: {"original": "cycling (sit)", "mapped": "cycling", "category": "transport"},
    14: {"original": "cycling (stand)", "mapped": "cycling", "category": "transport"},
    130: {"original": "cycling", "mapped": "cycling", "category": "transport"},

    17: {"original": "car", "mapped": "car", "category": "transport"},
    18: {"original": "bus", "mapped": "bus", "category": "transport"},
    19: {"original": "train", "mapped": "train", "category": "transport"},
    20: {"original": "subway", "mapped": "subway", "category": "transport"},

    # Transition label
    0: {"original": "transition", "mapped": "transition", "category": "transition"},
}

# Cross-dataset unified label superset (LARa + RealWorld + SHL)
UNIFIED_LABELS = {
    "walking": 1,
    "running": 2,
    "sitting": 3,
    "standing": 4,
    "upstairs": 5,
    "downstairs": 6,
    "lying": 7,
    "cycling": 8,
    "car": 9,
    "bus": 10,
    "train": 11,
    "subway": 12,
    "transition": 0,  # kept or cleaned
}

print(f"\nDefined mapping rules: {len(LABEL_MAPPING)} original labels")
print(f"Unified label set: {len(UNIFIED_LABELS)} labels (cross-dataset superset)")

print(f"\nMapping examples:")
for orig_id, info in list(LABEL_MAPPING.items())[:10]:
    print(f"  {orig_id} ({info['original']}) -> {info['mapped']}")

# ========== 3. Audit assertion: check unmapped labels ==========

print("\n" + "="*60)
print("3. Audit assertion: check unmapped labels")
print("="*60)

# Find all original label IDs (excluding NULL)
orig_ids = set(df['label'].dropna().astype(int).unique())
covered_ids = set(LABEL_MAPPING.keys())
unmapped_ids = sorted(orig_ids - covered_ids)

if unmapped_ids:
    # Count samples for unmapped labels
    unmapped_counts = []
    for uid in unmapped_ids:
        count = (df['label'] == uid).sum()
        pct = count / total_samples
        unmapped_counts.append({
            'original_label_id': uid,
            'sample_count': count,
            'percentage': round(pct * 100, 4),
        })

    df_unmapped = pd.DataFrame(unmapped_counts)
    total_unmapped = df_unmapped['sample_count'].sum()
    unmapped_ratio = total_unmapped / total_samples

    # Save list of unmapped labels
    unmapped_file = reports_dir / "unmapped_labels.csv"
    df_unmapped.to_csv(unmapped_file, index=False)

    print(f"\n⚠️ Found unmapped labels: {len(unmapped_ids)}")
    print(f"  Unmapped sample count: {total_unmapped:,} ({unmapped_ratio*100:.2f}%)")
    print(f"  Details saved to: {unmapped_file}")
    print(f"\nList of unmapped labels:")
    print(df_unmapped.to_string(index=False))

    # Abort if threshold exceeded
    if unmapped_ratio > UNMAPPED_THRESHOLD:
        raise RuntimeError(
            f"Unmapped label ratio {unmapped_ratio*100:.2f}% exceeds threshold {UNMAPPED_THRESHOLD*100}%. "
            f"Please check {unmapped_file} and extend LABEL_MAPPING."
        )
    else:
        print(f"\n✓ Unmapped label ratio does not exceed threshold {UNMAPPED_THRESHOLD*100}%; continuing (will mark as NULL)")
else:
    print(f"\n✓ All original labels are covered")

# ========== 4. Apply label mapping ==========

print("\n" + "="*60)
print("4. Apply label mapping")
print("="*60)

df_mapped = df.copy()

# Keep a copy of original labels (nullable integer)
df_mapped['label_original'] = df_mapped['label'].astype('Int32')

# Apply mapping
def map_label(label):
    """Map a single label"""
    if pd.isna(label):
        return np.nan

    label = int(label)
    if label in LABEL_MAPPING:
        mapped_name = LABEL_MAPPING[label]['mapped']
        return UNIFIED_LABELS[mapped_name]
    else:
        # Unknown labels marked as NaN
        return np.nan

df_mapped['label'] = df_mapped['label_original'].apply(map_label)

# Stats after mapping
mapped_label_counts = df_mapped['label'].value_counts(dropna=False)
null_after_mapping = df_mapped['label'].isna().sum()

print(f"\nPost-mapping label stats:")
print(f"  NULL samples: {null_after_mapping:,} ({null_after_mapping/total_samples*100:.2f}%)")
print(f"  Number of valid label classes: {df_mapped['label'].nunique(dropna=True)}")

print(f"\nPost-mapping distribution:")
for label, count in mapped_label_counts.head(15).items():
    pct = count / total_samples * 100
    # find label name
    label_name = "NULL"
    if not pd.isna(label):
        label_name = [k for k, v in UNIFIED_LABELS.items() if v == int(label)][0]
    print(f"  {label_name:15s} ({str(label):2s}): {count:8,} ({pct:5.2f}%)")

# ========== 5. Clean NULL and transition labels (true nearest neighbor) ==========

print("\n" + "="*60)
print("5. Clean NULL and transition labels (true nearest neighbor)")
print("="*60)

df_cleaned = df_mapped.copy()

# Handle NULL labels
if NULL_STRATEGY == "remove":
    null_mask = df_cleaned['label'].isna()
    removed_null = null_mask.sum()
    df_cleaned = df_cleaned[~null_mask].copy()
    print(f"\nNULL handling: removed {removed_null:,} samples")
elif NULL_STRATEGY == "merge_to_transition":
    null_mask = df_cleaned['label'].isna()
    df_cleaned.loc[null_mask, 'label'] = UNIFIED_LABELS['transition']
    print(f"\nNULL handling: merged into transition ({null_mask.sum():,} samples)")

# Detect time column
time_col = None
for candidate in ['time_sec', 'timestamp', 'timestamp_ms', 'time', 'epoch_ms']:
    if candidate in df_cleaned.columns:
        time_col = candidate
        break

if time_col:
    print(f"\nDetected time column: {time_col}")
else:
    print(f"\nNo time column detected; will process by index order")

# Handle transition label (true nearest neighbor)
transition_value = UNIFIED_LABELS['transition']
if TRANSITION_STRATEGY == "remove":
    trans_mask = df_cleaned['label'] == transition_value
    removed_trans = trans_mask.sum()
    df_cleaned = df_cleaned[~trans_mask].copy()
    print(f"Transition handling: removed {removed_trans:,} samples")

elif TRANSITION_STRATEGY == "merge_to_nearest":
    trans_mask = df_cleaned['label'] == transition_value
    trans_count = trans_mask.sum()

    if trans_count > 0:
        print(f"Transition handling: merge {trans_count:,} samples using nearest-neighbor interpolation")

        # Sort by time (ensure nearest-neighbor semantics)
        if time_col:
            df_cleaned = df_cleaned.sort_values(
                ['subject_id', 'session_id', 'placement', time_col],
                kind='stable'
            ).copy()
            print(f"  ✓ Sorted by [{time_col}]")
        else:
            df_cleaned = df_cleaned.sort_index(kind='stable').copy()
            print(f"  ⚠️ Sorted by index (no time column)")

        # True nearest-neighbor merge
        merged_count = 0
        for (subj, sess, plc), group in df_cleaned.groupby(
            ['subject_id', 'session_id', 'placement'], observed=True
        ):
            idx = group.index
            labels = df_cleaned.loc[idx, 'label'].copy()

            # Replace transition with NaN
            labels_with_nan = labels.replace(transition_value, np.nan).astype('float')

            if labels_with_nan.isna().any():
                # Use nearest interpolation (true nearest neighbor)
                labels_filled = labels_with_nan.interpolate(
                    method='nearest',
                    limit_direction='both'
                )

                # Count successfully merged items
                was_trans = (labels == transition_value)
                now_filled = labels_filled.notna()
                merged_this_group = (was_trans & now_filled).sum()
                merged_count += merged_this_group

                # Update labels (round then cast to int)
                df_cleaned.loc[idx, 'label'] = labels_filled.round()

        print(f"  ✓ Successfully merged {merged_count:,} transition samples to nearest labels")

        # Remove transitions that could not be merged (entire segments are transition)
        remaining_trans = (df_cleaned['label'] == transition_value).sum()
        if remaining_trans > 0:
            df_cleaned = df_cleaned[df_cleaned['label'] != transition_value].copy()
            print(f"  ✓ Removed remaining {remaining_trans:,} transition samples that could not be merged")

# Remove remaining NaNs
final_nan = df_cleaned['label'].isna().sum()
if final_nan > 0:
    df_cleaned = df_cleaned[df_cleaned['label'].notna()].copy()
    print(f"\nRemoved final residual NaN samples: {final_nan:,}")

# Cast to int32
df_cleaned['label'] = df_cleaned['label'].astype('int32')

# Reset index
df_cleaned = df_cleaned.reset_index(drop=True)

print(f"\nData after cleaning:")
print(f"  Samples: {len(df_cleaned):,}")
print(f"  Number of label classes: {df_cleaned['label'].nunique()}")
print(f"  Retention rate: {len(df_cleaned)/total_samples*100:.2f}%")

# ========== 6. Audit assertion: verify final label set ==========

print("\n" + "="*60)
print("6. Audit assertion: verify final label set")
print("="*60)

# Determine allowed label set
allowed_labels = set(UNIFIED_LABELS.values())
if TRANSITION_STRATEGY == "remove":
    allowed_labels.discard(UNIFIED_LABELS['transition'])

# Check actual label set
actual_labels = set(df_cleaned['label'].unique())
unexpected = sorted(actual_labels - allowed_labels)

if unexpected:
    raise RuntimeError(
        f"Illegal labels found after cleaning: {unexpected}\n"
        f"Allowed labels: {sorted(allowed_labels)}"
    )
else:
    print(f"✓ Final label set validation passed")
    print(f"  Allowed labels: {sorted(allowed_labels)}")
    print(f"  Actual labels: {sorted(actual_labels)}")

# ========== 7. Final label distribution ==========

print("\n" + "="*60)
print("7. Final label distribution")
print("="*60)

final_label_counts = df_cleaned['label'].value_counts()

print(f"\nFinal label distribution:")
for label_id, count in final_label_counts.items():
    pct = count / len(df_cleaned) * 100
    label_name = [k for k, v in UNIFIED_LABELS.items() if v == int(label_id)][0]
    print(f"  {label_name:15s} ({int(label_id):2d}): {count:8,} ({pct:5.2f}%)")

# By-category statistics
category_stats = {}
for label_id, count in final_label_counts.items():
    label_name = [k for k, v in UNIFIED_LABELS.items() if v == int(label_id)][0]
    # Find category
    category = None
    for orig_id, info in LABEL_MAPPING.items():
        if info['mapped'] == label_name:
            category = info['category']
            break

    if category:
        category_stats[category] = category_stats.get(category, 0) + count

print(f"\nBy-category statistics:")
for category, count in sorted(category_stats.items()):
    pct = count / len(df_cleaned) * 100
    print(f"  {category:15s}: {count:8,} ({pct:5.2f}%)")

# ========== 8. Save results ==========

print("\n" + "="*60)
print("8. Save results")
print("="*60)

# Save cleaned data (using directory layout)
output_dir = proc_dir / "labeled"
if output_dir.exists():
    import shutil
    shutil.rmtree(output_dir)

df_cleaned.to_parquet(
    output_dir,
    index=False,
    partition_cols=['subject_id', 'placement'],
    engine='pyarrow'
)

print(f"✓ Saved: {output_dir}/")
print(f"  Data shape: {df_cleaned.shape}")
print(f"  Partitions: subject_id / placement")

# ========== 9. Save label mapping config (rich) ==========

print("\n" + "="*60)
print("9. Save label mapping config (rich)")
print("="*60)

# Build labels_map with more info
labels_map_data = []
for label_name, label_id in sorted(UNIFIED_LABELS.items(), key=lambda x: x[1]):
    if label_name == "transition" and TRANSITION_STRATEGY == "remove":
        continue  # exclude removed transition

    # Find original label IDs and names
    original_ids = []
    original_names = []
    category = None

    for orig_id, info in LABEL_MAPPING.items():
        if info['mapped'] == label_name:
            original_ids.append(str(orig_id))
            original_names.append(info['original'])
            if category is None:
                category = info['category']

    # Actual sample count
    sample_count = final_label_counts.get(label_id, 0)

    labels_map_data.append({
        'label_id': label_id,
        'label_name': label_name,
        'category': category or 'unknown',
        'sample_count': int(sample_count),
        'percentage': round(sample_count / len(df_cleaned) * 100, 2) if len(df_cleaned) > 0 else 0.0,
        'original_label_ids': ','.join(original_ids) if original_ids else '',
        'original_label_names': '; '.join(original_names) if original_names else '',
        'source_dataset': 'LARa-MbientLab',
        'description': f"{label_name} activity",
    })

df_labels_map = pd.DataFrame(labels_map_data)
labels_map_file = proc_dir / "labels_map.csv"
df_labels_map.to_csv(labels_map_file, index=False)

print(f"✓ Saved label mapping: {labels_map_file}")
print(f"\nLabel mapping table:")
print(df_labels_map.to_string(index=False))

# Save detailed configuration
label_config = {
    'dataset': 'LARa-MbientLab',
    'label_system': 'Cross-dataset unified label superset (covers LARa/RealWorld/SHL)',
    'unified_labels': UNIFIED_LABELS,
    'label_mapping': LABEL_MAPPING,
    'cleaning_strategy': {
        'null_strategy': NULL_STRATEGY,
        'transition_strategy': TRANSITION_STRATEGY,
        'transition_method': 'nearest-neighbor interpolation (true nearest neighbor)' if TRANSITION_STRATEGY == 'merge_to_nearest' else 'remove',
        'time_sorted': time_col is not None,
        'time_column': time_col,
        'unmapped_threshold': UNMAPPED_THRESHOLD,
    },
    'statistics': {
        'original_samples': int(total_samples),
        'cleaned_samples': int(len(df_cleaned)),
        'removed_samples': int(total_samples - len(df_cleaned)),
        'removal_rate': float((total_samples - len(df_cleaned)) / total_samples),
        'original_label_count': int(df['label'].nunique(dropna=True)),
        'final_label_count': int(df_cleaned['label'].nunique()),
        'unmapped_label_count': len(unmapped_ids) if unmapped_ids else 0,
    },
    'label_distribution': {
        label_name: int(final_label_counts.get(label_id, 0))
        for label_name, label_id in UNIFIED_LABELS.items()
        if label_name != 'transition' or TRANSITION_STRATEGY != 'remove'
    },
    'notes': [
        'Label mapping based on cross-dataset unified label superset (LARa + RealWorld + SHL)',
        f'NULL label strategy: {NULL_STRATEGY}',
        f'Transition label strategy: {TRANSITION_STRATEGY} (true nearest-neighbor interpolation)',
        'Unmapped original labels are automatically marked as NULL',
        f'Unmapped label threshold: {UNMAPPED_THRESHOLD*100}%',
        f'Sorted by time column: {time_col if time_col else "No (by index)"}',
        'Mapping table saved at proc/labels_map.csv',
        'label_original column uses nullable integer Int32',
        'Includes audit assertions to ensure label set integrity',
    ]
}

label_config_file = configs_dir / "labels.yaml"
with open(label_config_file, 'w', encoding='utf-8') as f:
    yaml.dump(label_config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

print(f"✓ Saved config: {label_config_file}")

label_config_json = configs_dir / "labels.json"
with open(label_config_json, 'w', encoding='utf-8') as f:
    json.dump(label_config, f, indent=2)

print(f"✓ Saved config: {label_config_json}")

# ========== 10. Summary ==========

print("\n" + "="*60)
print("Step 8 complete - Label alignment & cleaning (top-tier revised)")
print("="*60)

print(f"\nConfig:")
print(f"  Label system: cross-dataset unified superset (LARa/RealWorld/SHL)")
print(f"  NULL strategy: {NULL_STRATEGY}")
print(f"  Transition strategy: {TRANSITION_STRATEGY} (true nearest neighbor)")
print(f"  Unmapped threshold: {UNMAPPED_THRESHOLD*100}%")
print(f"  Time column: {time_col if time_col else 'No (by index)'}")

print(f"\nResults:")
print(f"  Original samples: {total_samples:,}")
print(f"  Cleaned samples: {len(df_cleaned):,}")
print(f"  Removed samples: {total_samples - len(df_cleaned):,}")
print(f"  Retention rate: {len(df_cleaned)/total_samples*100:.2f}%")

print(f"\nLabel stats:")
print(f"  Original label classes: {df['label'].nunique(dropna=True)}")
print(f"  Final label classes: {df_cleaned['label'].nunique()}")
print(f"  Unmapped labels: {len(unmapped_ids) if unmapped_ids else 0}")

print(f"\nOutputs:")
print(f"  Data: {output_dir}/")
print(f"  Mapping table: {labels_map_file}")
print(f"  Config: {label_config_file}")
if unmapped_ids:
    print(f"  Unmapped list: {reports_dir / 'unmapped_labels.csv'}")

print("\nKey fixes (top-tier):")
print("  1. ✓ True nearest-neighbor merge (interpolate method='nearest')")
print("  2. ✓ Sort by time before processing (correct semantics)")
print("  3. ✓ Record unmapped labels to reports/unmapped_labels.csv")
print("  4. ✓ label_original uses nullable Int32")
print("  5. ✓ Removed irrelevant MAJORITY_VOTE_THRESHOLD")
print("  6. ✓ Audit assertions (fail-fast)")
print("  7. ✓ labels_map.csv includes original names and source")
print("  8. ✓ Label system described as cross-dataset superset")
print("  9. ✓ Output directory changed to labeled/")
print("="*60)

Step 8: Label Alignment & Cleaning

Loading normalized data: data/lara/mbientlab/proc/normalized.parquet
Data shape: (560070, 15)
Number of subjects: 8

1. Analyze original label distribution

Original label stats:
  Total samples: 560,070
  NULL samples: 28 (0.00%)
  Number of label classes: 8

Label distribution (top 20):
  4.0                           :  215,988 (38.56%)
  2.0                           :   86,132 (15.38%)
  0.0                           :   73,180 (13.07%)
  7.0                           :   48,927 ( 8.74%)
  5.0                           :   43,966 ( 7.85%)
  1.0                           :   42,039 ( 7.51%)
  3.0                           :   38,224 ( 6.82%)
  6.0                           :   11,586 ( 2.07%)
  nan                           :       28 ( 0.00%)

2. Define label mapping rules

Defined mapping rules: 16 original labels
Unified label set: 13 labels (cross-dataset superset)

Mapping examples:
  1 (walking) -> walking
  2 (running) -> running
  3 (shuf

In [5]:
import os

"""
Step 9: Sliding-window Slicing (top-conf/journal grade - multi-fold version)
Slice with fixed window length/step; assign window label by majority label
For each fold in configs/splits.json, generate windows/{fold_xx}/X_train.npy, X_test.npy, etc.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import yaml
import json
from collections import Counter

# ========== Config ==========

# Sliding-window parameters
SAMPLING_RATE_HZ = 50.0
WINDOW_SIZE_SEC = 3.0
OVERLAP_RATIO = 0.5

# Compute sample counts
WINDOW_SIZE = int(WINDOW_SIZE_SEC * SAMPLING_RATE_HZ)        # 150 samples
STEP_SIZE = int(WINDOW_SIZE * (1 - OVERLAP_RATIO))           # 75 samples

# Majority label threshold
DOMINANT_THRESHOLD = 0.8

# Feature columns (8 channels)
FEATURE_COLS = ['ax', 'ay', 'az', 'gx', 'gy', 'gz', 'acc_mag', 'gyr_mag']

print("="*60)
print("Step 9: Sliding-window slicing (multi-fold)")
print("="*60)

# Base directories
proc_dir = Path("data/lara/mbientlab/proc")
configs_dir = Path("configs")
windows_root = proc_dir / "windows"
windows_root.mkdir(parents=True, exist_ok=True)

print(f"\nSliding-window parameters:")
print(f"  Window length: {WINDOW_SIZE_SEC} s = {WINDOW_SIZE} samples @ {SAMPLING_RATE_HZ} Hz")
print(f"  Step size: {STEP_SIZE} samples (overlap {OVERLAP_RATIO*100:.0f}%)")
print(f"  Dominant label threshold: {DOMINANT_THRESHOLD*100:.0f}%")
print(f"  Feature columns: {FEATURE_COLS}")

# ========== 1. Load data ==========

print("\n" + "="*60)
print("1. Load cleaned & labeled data")
print("="*60)

labeled_dir = proc_dir / "labeled"
print(f"Loading data from: {labeled_dir}/")
df = pd.read_parquet(labeled_dir)

print(f"Data shape: {df.shape}")
print(f"Number of subjects: {df['subject_id'].nunique()}")
print(f"Number of label classes: {df['label'].nunique()}")

# Check required columns
required_cols = ['subject_id', 'session_id', 'placement', 'label'] + FEATURE_COLS
missing_cols = [c for c in required_cols if c not in df.columns]
if missing_cols:
    raise ValueError(f"Missing required columns: {missing_cols}")

# Detect time column
time_col = 'time_sec' if 'time_sec' in df.columns else None
if time_col:
    print(f"Time column: {time_col}")
else:
    print("No time column detected; will sort by index")

# ========== 2. Load splits and determine folds ==========

print("\n" + "="*60)
print("2. Load train/test splits (folds)")
print("="*60)

splits_path = configs_dir / "splits.json"
fold_ids = []
splits = None

if splits_path.exists():
    with open(splits_path, "r") as f:
        splits = json.load(f)

    # Expect keys like "0","1","2",...
    fold_ids = sorted(int(k) for k in splits.keys())
    print(f"Detected {len(fold_ids)} folds from {splits_path}: {fold_ids}")
else:
    # Fallback: single "all" fold (no LOSO config)
    print("⚠️ splits.json not found; will treat all data as a single 'all' fold")
    fold_ids = [None]

# ========== 3. Sliding-window function (with time continuity check) ==========

def sliding_window_extract(df_subset, window_size, step_size, dominant_threshold, time_col=None):
    """
    Perform sliding-window slicing grouped by session.

    Returns:
        windows_list: list of window feature arrays
        metadata_list: list of window metadata dicts
    """
    windows_list = []
    metadata_list = []
    window_id = 0

    # Group by session + placement
    for (subj, sess, plc), group in df_subset.groupby(
        ['subject_id', 'session_id', 'placement'], observed=True
    ):
        # Sort by time column (preferred), otherwise by index
        if time_col and time_col in group.columns:
            group = group.sort_values(time_col, kind='stable').copy()
        else:
            group = group.sort_index(kind='stable').copy()

        # Extract features and labels
        features = group[FEATURE_COLS].values
        labels = group['label'].values

        # Extract timestamps (if any)
        if time_col and time_col in group.columns:
            timestamps = group[time_col].values
        else:
            timestamps = None

        # Sliding-window slicing
        n_samples = len(group)
        for start_idx in range(0, n_samples - window_size + 1, step_size):
            end_idx = start_idx + window_size

            # Extract window
            window_features = features[start_idx:end_idx]
            window_labels = labels[start_idx:end_idx]

            # Check NaNs
            if np.isnan(window_features).any():
                continue

            # Time continuity check (if timestamps exist)
            if timestamps is not None:
                expected_duration = (window_size - 1) / SAMPLING_RATE_HZ
                actual_duration = timestamps[end_idx - 1] - timestamps[start_idx]
                # Allow 10% jitter
                if abs(actual_duration - expected_duration) > 0.1 * expected_duration:
                    continue

            # Compute dominant label
            label_counts = Counter(window_labels)
            dominant_label, dominant_count = label_counts.most_common(1)[0]
            dominant_ratio = dominant_count / window_size

            # Keep only windows that meet the threshold
            if dominant_ratio < dominant_threshold:
                continue

            # Extract time range
            if timestamps is not None:
                time_start = timestamps[start_idx]
                time_end = timestamps[end_idx - 1]
                time_range = f"{time_start:.3f}-{time_end:.3f}"
            else:
                time_range = f"{start_idx}-{end_idx-1}"

            # Save window
            windows_list.append(window_features)

            # Save metadata
            metadata_list.append({
                'window_id': window_id,
                'subject_id': subj,
                'session_id': sess,
                'placement': plc,
                'label': int(dominant_label),
                'label_purity': round(dominant_ratio, 4),
                'time_range': time_range,
                'start_idx': start_idx,
                'end_idx': end_idx,
            })

            window_id += 1

    return windows_list, metadata_list

# ========== 4. Loop over folds and extract windows ==========

print("\n" + "="*60)
print("3. Extract windows for each fold")
print("="*60)

fold_stats = {}  # for global config (per-fold statistics)

for fold_id in fold_ids:
    if fold_id is None:
        fold_tag = "all"
        print(f"\n--- Processing pseudo-fold: {fold_tag} (all data) ---")

        train_subjects = set(df['subject_id'].unique())
        test_subjects = set()
        df_train = df.copy()
        df_test = pd.DataFrame()
    else:
        fold_tag = f"fold_{fold_id:02d}"
        print(f"\n--- Processing fold {fold_id} ({fold_tag}) ---")

        fold_cfg = splits[str(fold_id)]
        train_subjects = set(fold_cfg["train_subjects"])
        test_subjects = set(fold_cfg["test_subjects"])

        # Split data
        df_train = df[df['subject_id'].isin(train_subjects)].copy()
        df_test = df[df['subject_id'].isin(test_subjects)].copy()

        print(f"  Train subjects: {len(train_subjects)}")
        print(f"  Test subjects:  {len(test_subjects)}")
        print(f"  Train samples:  {len(df_train):,}")
        print(f"  Test samples:   {len(df_test):,}")

    # Create output directory for this fold
    windows_dir = windows_root / fold_tag
    windows_dir.mkdir(parents=True, exist_ok=True)
    print(f"  Output directory: {windows_dir}")

    # ----- 4.1 Extract training-set windows -----
    print("\n  [Train] Sliding-window extraction")

    train_windows = []
    df_train_meta = pd.DataFrame()
    train_label_counts = pd.Series(dtype=int)

    if not df_train.empty:
        print(f"  Processing train set ({len(df_train):,} samples)...")

        train_windows, train_metadata = sliding_window_extract(
            df_train, WINDOW_SIZE, STEP_SIZE, DOMINANT_THRESHOLD, time_col
        )

        print(f"  ✓ Extracted train windows: {len(train_windows):,}")

        if train_windows:
            # To numpy array
            X_train = np.array(train_windows, dtype='float32')  # (n_windows, window_size, n_features)
            df_train_meta = pd.DataFrame(train_metadata)

            print(f"    X_train shape: {X_train.shape}")
            print(f"    Feature dims : {X_train.shape[2]} channels × {X_train.shape[1]} timesteps")

            # Label distribution
            train_label_counts = df_train_meta['label'].value_counts().sort_index()
            print(f"\n    Train-set label distribution:")
            for label, count in train_label_counts.items():
                pct = count / len(df_train_meta) * 100
                print(f"      Label {label}: {count:6,} windows ({pct:5.2f}%)")

            # Label purity stats
            avg_purity = df_train_meta['label_purity'].mean()
            min_purity = df_train_meta['label_purity'].min()
            print(f"\n    Train-set label purity:")
            print(f"      Mean: {avg_purity*100:.2f}%")
            print(f"      Min:  {min_purity*100:.2f}%")

            # Save train set
            print(f"\n    Saving train set...")

            # Save features (numpy)
            X_train_npy_file = windows_dir / "X_train.npy"
            np.save(X_train_npy_file, X_train)
            print(f"      ✓ {X_train_npy_file} (feature tensor)")

            # Save metadata (Parquet)
            X_train_meta_file = windows_dir / "X_train.parquet"
            df_train_meta[['window_id', 'subject_id', 'session_id', 'placement',
                           'label', 'label_purity', 'time_range', 'start_idx', 'end_idx']].to_parquet(
                X_train_meta_file, index=False
            )
            print(f"      ✓ {X_train_meta_file} (metadata)")

            # Save label vector
            y_train = df_train_meta['label'].values.astype('int32')
            y_train_file = windows_dir / "y_train.npy"
            np.save(y_train_file, y_train)
            print(f"      ✓ {y_train_file}")

            # Export label distribution snapshot (for audit)
            train_label_counts.to_csv(windows_dir / "train_label_counts.csv", header=['count'])
            print(f"      ✓ train_label_counts.csv")
        else:
            print("  ⚠️ No train windows extracted")
    else:
        print("  Train set is empty; skipping")

    # ----- 4.2 Extract test-set windows -----
    print("\n  [Test] Sliding-window extraction")

    test_windows = []
    df_test_meta = pd.DataFrame()
    test_label_counts = pd.Series(dtype=int)

    if not df_test.empty:
        print(f"  Processing test set ({len(df_test):,} samples)...")

        test_windows, test_metadata = sliding_window_extract(
            df_test, WINDOW_SIZE, STEP_SIZE, DOMINANT_THRESHOLD, time_col
        )

        print(f"  ✓ Extracted test windows: {len(test_windows):,}")

        if test_windows:
            # To numpy array
            X_test = np.array(test_windows, dtype='float32')
            df_test_meta = pd.DataFrame(test_metadata)

            print(f"    X_test shape: {X_test.shape}")

            # Label distribution
            test_label_counts = df_test_meta['label'].value_counts().sort_index()
            print(f"\n    Test-set label distribution:")
            for label, count in test_label_counts.items():
                pct = count / len(df_test_meta) * 100
                print(f"      Label {label}: {count:6,} windows ({pct:5.2f}%)")

            # Label purity stats
            avg_purity = df_test_meta['label_purity'].mean()
            min_purity = df_test_meta['label_purity'].min()
            print(f"\n    Test-set label purity:")
            print(f"      Mean: {avg_purity*100:.2f}%")
            print(f"      Min:  {min_purity*100:.2f}%")

            # Save test set
            print(f"\n    Saving test set...")

            # Save features (numpy)
            X_test_npy_file = windows_dir / "X_test.npy"
            np.save(X_test_npy_file, X_test)
            print(f"      ✓ {X_test_npy_file} (feature tensor)")

            # Save metadata (Parquet)
            X_test_meta_file = windows_dir / "X_test.parquet"
            df_test_meta[['window_id', 'subject_id', 'session_id', 'placement',
                          'label', 'label_purity', 'time_range', 'start_idx', 'end_idx']].to_parquet(
                X_test_meta_file, index=False
            )
            print(f"      ✓ {X_test_meta_file} (metadata)")

            # Save label vector
            y_test = df_test_meta['label'].values.astype('int32')
            y_test_file = windows_dir / "y_test.npy"
            np.save(y_test_file, y_test)
            print(f"      ✓ {y_test_file}")

            # Export label distribution snapshot (for audit)
            test_label_counts.to_csv(windows_dir / "test_label_counts.csv", header=['count'])
            print(f"      ✓ test_label_counts.csv")
        else:
            print("  ⚠️ No test windows extracted")
    else:
        print("  Test set is empty; skipping")

    # ----- 4.3 Collect statistics for this fold -----
    fold_key = "all" if fold_id is None else str(fold_id)
    fold_stats[fold_key] = {}

    if train_windows:
        fold_stats[fold_key]['train'] = {
            'n_windows': int(len(train_windows)),
            'n_subjects': int(df_train_meta['subject_id'].nunique()),
            'n_sessions': int(df_train_meta.groupby(['subject_id', 'session_id']).ngroups),
            'label_distribution': {int(k): int(v) for k, v in train_label_counts.items()},
            'avg_label_purity': round(float(df_train_meta['label_purity'].mean()), 4),
            'min_label_purity': round(float(df_train_meta['label_purity'].min()), 4),
        }

    if test_windows:
        fold_stats[fold_key]['test'] = {
            'n_windows': int(len(test_windows)),
            'n_subjects': int(df_test_meta['subject_id'].nunique()),
            'n_sessions': int(df_test_meta.groupby(['subject_id', 'session_id']).ngroups),
            'label_distribution': {int(k): int(v) for k, v in test_label_counts.items()},
            'avg_label_purity': round(float(df_test_meta['label_purity'].mean()), 4),
            'min_label_purity': round(float(df_test_meta['label_purity'].min()), 4),
        }

# ========== 5. Save window configuration (global, multi-fold) ==========

print("\n" + "="*60)
print("4. Save window configuration (global)")
print("="*60)

fold_ids_str = ["all" if fid is None else str(fid) for fid in fold_ids]

window_config = {
    'window_parameters': {
        'sampling_rate_hz': SAMPLING_RATE_HZ,
        'window_size_sec': WINDOW_SIZE_SEC,
        'window_size_samples': WINDOW_SIZE,
        'overlap_ratio': OVERLAP_RATIO,
        'step_size_samples': STEP_SIZE,
        'dominant_threshold': DOMINANT_THRESHOLD,
    },
    'features': {
        'channels': FEATURE_COLS,
        'n_channels': len(FEATURE_COLS),
        'description': '8-channel IMU features (ax,ay,az,gx,gy,gz,acc_mag,gyr_mag)',
    },
    'dataset_split': {
        'num_folds': len(fold_ids),
        'fold_ids': fold_ids_str,
        'source': str(splits_path) if splits_path.exists() else None,
    },
    'statistics': fold_stats,
    'notes': [
        f'Window parameters: {WINDOW_SIZE_SEC}s @ {SAMPLING_RATE_HZ}Hz = {WINDOW_SIZE} samples',
        f'Step size: {STEP_SIZE} samples (overlap {OVERLAP_RATIO*100:.0f}%)',
        f'Dominant label threshold: {DOMINANT_THRESHOLD*100:.0f}% (discard windows below threshold)',
        'Features: 8 channels (3-axis accelerometer + 3-axis gyroscope + 2 magnitudes)',
        'Data formats: X_*.npy (float32 tensor), X_*.parquet (metadata), y_*.npy (int32)',
        'Metadata includes: window_id/time_range/label/label_purity, etc.',
        'Slice per session to ensure temporal continuity',
        f'Order by {time_col if time_col else "index"}',
        'Discard windows containing NaN',
        'Time continuity check (allow 10% jitter)',
        'Persist by fold: windows/fold_xx/ (avoid overwrite when looping over folds)',
    ]
}

window_config_file = configs_dir / "windows.yaml"
with open(window_config_file, 'w', encoding='utf-8') as f:
    yaml.dump(window_config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
print(f"✓ Saved config: {window_config_file}")

window_config_json = configs_dir / "windows.json"
with open(window_config_json, 'w', encoding='utf-8') as f:
    json.dump(window_config, f, indent=2)
print(f"✓ Saved config: {window_config_json}")

# ========== 6. Summary ==========

print("\n" + "="*60)
print("Step 9 complete - Sliding-window slicing (multi-fold)")
print("="*60)

print(f"\nWindow parameters:")
print(f"  Window length: {WINDOW_SIZE_SEC} s = {WINDOW_SIZE} samples")
print(f"  Step size: {STEP_SIZE} samples (overlap {OVERLAP_RATIO*100:.0f}%)")
print(f"  Dominant threshold: {DOMINANT_THRESHOLD*100:.0f}%")
print(f"  Feature dimension: {len(FEATURE_COLS)} channels")
print(f"  Sort order: {time_col if time_col else 'index'}")

print("\nPer-fold window statistics:")
for fold_key, stats in fold_stats.items():
    print(f"\n  Fold {fold_key}:")
    if 'train' in stats:
        tr = stats['train']
        print(f"    Train: {tr['n_windows']} windows, "
              f"{tr['n_subjects']} subjects, {tr['n_sessions']} sessions, "
              f"avg purity {tr['avg_label_purity']*100:.2f}%")
    else:
        print("    Train: (no windows)")
    if 'test' in stats:
        te = stats['test']
        print(f"    Test : {te['n_windows']} windows, "
              f"{te['n_subjects']} subjects, {te['n_sessions']} sessions, "
              f"avg purity {te['avg_label_purity']*100:.2f}%")
    else:
        print("    Test : (no windows)")

print(f"\nOutputs per fold:")
print(f"  Root directory: {windows_root}/")
print(f"  For each fold: X_train.npy, X_train.parquet, y_train.npy, "
      f"train_label_counts.csv (+ test equivalents when applicable)")
print(f"  Global config: {window_config_file}, {window_config_json}")
print("="*60)

Step 9: Sliding-window slicing (multi-fold)

Sliding-window parameters:
  Window length: 3.0 s = 150 samples @ 50.0 Hz
  Step size: 75 samples (overlap 50%)
  Dominant label threshold: 80%
  Feature columns: ['ax', 'ay', 'az', 'gx', 'gy', 'gz', 'acc_mag', 'gyr_mag']

1. Load cleaned & labeled data
Loading data from: data/lara/mbientlab/proc/labeled/
Data shape: (556504, 16)
Number of subjects: 8
Number of label classes: 6
Time column: time_sec

2. Load train/test splits (folds)
Detected 8 folds from configs/splits.json: [0, 1, 2, 3, 4, 5, 6, 7]

3. Extract windows for each fold

--- Processing fold 0 (fold_00) ---
  Train subjects: 7
  Test subjects:  1
  Train samples:  479,982
  Test samples:   76,522
  Output directory: data/lara/mbientlab/proc/windows/fold_00

  [Train] Sliding-window extraction
  Processing train set (479,982 samples)...
  ✓ Extracted train windows: 4,965
    X_train shape: (4965, 150, 8)
    Feature dims : 8 channels × 150 timesteps

    Train-set label distribut

In [6]:
#!/usr/bin/env python3

"""
Step 10: LOSO Split (top-conf/journal grade)
Leave-One-Subject-Out: 1 subject for test per fold, the rest for training
"""

import pandas as pd
import numpy as np
from pathlib import Path
import json
import yaml
from collections import defaultdict

print("="*60)
print("Step 10: LOSO split")
print("="*60)

# Path configuration
proc_dir = Path("data/lara/mbientlab/proc")
configs_dir = Path("configs")
configs_dir.mkdir(parents=True, exist_ok=True)

# ========== 1. Load data and get subject list ==========

print("\n" + "="*60)
print("1. Load data and get subject list")
print("="*60)

labeled_dir = proc_dir / "labeled"
print(f"Loading data: {labeled_dir}/")

df = pd.read_parquet(labeled_dir)

print(f"Data shape: {df.shape}")
print(f"Total samples: {len(df):,}")

# Extract all subjects
all_subjects = sorted(df['subject_id'].unique().tolist())
n_subjects = len(all_subjects)

print(f"\nSubject list:")
print(f"  Total: {n_subjects} subjects")
print(f"  IDs: {all_subjects}")

# Sample count per subject
subject_sample_counts = df['subject_id'].value_counts().sort_index()
print(f"\nSample count per subject:")
for subj in all_subjects:
    count = subject_sample_counts.get(subj, 0)
    pct = count / len(df) * 100
    print(f"  {subj}: {count:8,} samples ({pct:5.2f}%)")

# ========== 2. Generate LOSO split ==========

print("\n" + "="*60)
print("2. Generate LOSO split")
print("="*60)

print(f"\nLOSO strategy: Leave-One-Subject-Out")
print(f"  #folds = #subjects = {n_subjects}")
print(f"  Per fold: 1 subject for test, {n_subjects-1} subjects for train")

# Create split dict
splits = {}

for fold_id, test_subject in enumerate(all_subjects):
    # Test set: current subject
    test_subjects = [test_subject]

    # Train set: all other subjects
    train_subjects = [s for s in all_subjects if s != test_subject]

    # Save split
    splits[str(fold_id)] = {
        "fold_id": fold_id,
        "test_subject": test_subject,
        "test_subjects": test_subjects,  # list for compatibility
        "train_subjects": train_subjects,
        "n_train": len(train_subjects),
        "n_test": len(test_subjects),
    }

    print(f"  Fold {fold_id}: test {test_subject}, train {len(train_subjects)} subjects")

print(f"\n✓ Generated {len(splits)} LOSO folds")

# ========== 3. Validate split integrity ==========

print("\n" + "="*60)
print("3. Validate split integrity")
print("="*60)

# Check 1: each subject appears exactly once in the test set
test_subject_appearances = defaultdict(int)
for fold_id, fold_info in splits.items():
    for subj in fold_info['test_subjects']:
        test_subject_appearances[subj] += 1

print(f"\nCheck 1: times each subject appears as test")
all_once = True
for subj in all_subjects:
    count = test_subject_appearances[subj]
    status = "✓" if count == 1 else "✗"
    print(f"  {status} {subj}: {count} time(s)")
    if count != 1:
        all_once = False

if all_once:
    print(f"  ✓ All subjects appear exactly once")
else:
    raise RuntimeError("Split validation failed: subject test appearances not equal to 1")

# Check 2: train and test sets are disjoint
print(f"\nCheck 2: train and test sets are disjoint")
all_disjoint = True
for fold_id, fold_info in splits.items():
    train_set = set(fold_info['train_subjects'])
    test_set = set(fold_info['test_subjects'])
    overlap = train_set & test_set

    if overlap:
        print(f"  ✗ Fold {fold_id}: overlap exists {overlap}")
        all_disjoint = False

if all_disjoint:
    print(f"  ✓ Train/test sets are completely disjoint for all folds")
else:
    raise RuntimeError("Split validation failed: train and test sets have overlap")

# Check 3: all subjects covered
print(f"\nCheck 3: all subjects covered")
covered_subjects = set()
for fold_id, fold_info in splits.items():
    covered_subjects.update(fold_info['train_subjects'])
    covered_subjects.update(fold_info['test_subjects'])

missing = set(all_subjects) - covered_subjects
extra = covered_subjects - set(all_subjects)

if not missing and not extra:
    print(f"  ✓ All subjects are covered; no missing or extra subjects")
else:
    if missing:
        print(f"  ✗ Missing subjects: {missing}")
    if extra:
        print(f"  ✗ Extra subjects: {extra}")
    raise RuntimeError("Split validation failed: subject coverage incomplete")

# Check 4: sample count stats
print(f"\nCheck 4: per-fold sample counts")
fold_sample_stats = []
for fold_id, fold_info in splits.items():
    train_subjects = fold_info['train_subjects']
    test_subjects = fold_info['test_subjects']

    n_train_samples = df[df['subject_id'].isin(train_subjects)].shape[0]
    n_test_samples = df[df['subject_id'].isin(test_subjects)].shape[0]

    fold_sample_stats.append({
        'fold_id': int(fold_id),
        'test_subject': fold_info['test_subject'],
        'n_train_samples': n_train_samples,
        'n_test_samples': n_test_samples,
        'train_ratio': round(n_train_samples / len(df), 4),
        'test_ratio': round(n_test_samples / len(df), 4),
    })

df_fold_stats = pd.DataFrame(fold_sample_stats)

print(f"\nPer-fold sample distribution:")
print(df_fold_stats.to_string(index=False))

# Summary
print(f"\nSample distribution summary:")
print(f"  Train sample count: {df_fold_stats['n_train_samples'].min():,} ~ {df_fold_stats['n_train_samples'].max():,}")
print(f"  Test sample count: {df_fold_stats['n_test_samples'].min():,} ~ {df_fold_stats['n_test_samples'].max():,}")
print(f"  Average train ratio: {df_fold_stats['train_ratio'].mean()*100:.2f}%")
print(f"  Average test ratio: {df_fold_stats['test_ratio'].mean()*100:.2f}%")

print(f"\n✓ All validations passed")

# ========== 4. Save split configuration ==========

print("\n" + "="*60)
print("4. Save split configuration")
print("="*60)

# Save splits.json
splits_file = configs_dir / "splits.json"
with open(splits_file, 'w', encoding='utf-8') as f:
    json.dump(splits, f, indent=2)

print(f"✓ Saved: {splits_file}")

# Save detailed config (with metadata)
loso_config = {
    'strategy': 'LOSO (Leave-One-Subject-Out)',
    'description': 'One subject for test in each fold; remaining subjects for training',
    'n_folds': n_subjects,
    'n_subjects': n_subjects,
    'all_subjects': all_subjects,
    'fold_statistics': {
        'train_samples_min': int(df_fold_stats['n_train_samples'].min()),
        'train_samples_max': int(df_fold_stats['n_train_samples'].max()),
        'train_samples_mean': int(df_fold_stats['n_train_samples'].mean()),
        'test_samples_min': int(df_fold_stats['n_test_samples'].min()),
        'test_samples_max': int(df_fold_stats['n_test_samples'].max()),
        'test_samples_mean': int(df_fold_stats['n_test_samples'].mean()),
        'avg_train_ratio': round(float(df_fold_stats['train_ratio'].mean()), 4),
        'avg_test_ratio': round(float(df_fold_stats['test_ratio'].mean()), 4),
    },
    'validation': {
        'no_subject_overlap': True,
        'all_subjects_covered': True,
        'each_subject_tested_once': True,
    },
    'anti_leakage_principles': [
        'Train and test sets are completely separated by subject',
        'Window slicing is performed after splitting to ensure no cross-fold leakage',
        'Statistics (mean/std) are computed from the training fold only',
        'Feature engineering is performed independently within each fold',
        'Hyperparameter tuning uses training-fold data only (nested CV optional)',
        'Final model evaluation is strictly based on the corresponding fold’s test set',
        'When aggregating results across folds, use metrics from independent test sets',
    ],
    'notes': [
        f'LOSO split: {n_subjects} folds; 1 subject per fold for test',
        'Ensure each subject appears exactly once in the test set',
        'Train/test sets are mutually exclusive with no subject overlap',
        'Suitable for small-sample settings with large inter-subject variability',
        'Report mean and standard deviation across all folds',
    ]
}

loso_config_file = configs_dir / "loso.yaml"
with open(loso_config_file, 'w', encoding='utf-8') as f:
    yaml.dump(loso_config, f, default_flow_style=False, allow_unicode=True, sort_keys=False)

print(f"✓ Saved: {loso_config_file}")

loso_config_json = configs_dir / "loso.json"
with open(loso_config_json, 'w', encoding='utf-8') as f:
    json.dump(loso_config, f, indent=2)

print(f"✓ Saved: {loso_config_json}")

# Save per-fold sample stats
fold_stats_file = configs_dir / "loso_fold_stats.csv"
df_fold_stats.to_csv(fold_stats_file, index=False)
print(f"✓ Saved: {fold_stats_file}")

# ========== 5. Generate usage example ==========

print("\n" + "="*60)
print("5. Generate usage example")
print("="*60)

example_code = '''
# ========== LOSO Usage Example ==========

import json
from pathlib import Path

# 1. Load splits
with open("configs/splits.json", "r") as f:
    splits = json.load(f)

# 2. Iterate over folds
for fold_id in range(len(splits)):
    print(f"\\n========== Fold {fold_id} ==========")

    # Get current fold split
    fold = splits[str(fold_id)]
    train_subjects = fold["train_subjects"]
    test_subject = fold["test_subject"]

    print(f"Train: {len(train_subjects)} subjects")
    print(f"Test: {test_subject}")

    # 3. Set environment variable (used by later steps)
    import os
    os.environ["FOLD_ID"] = str(fold_id)

    # 4. Run training pipeline
    # - Step 6: per-fold clipping (statistics from train only)
    # - Step 7: per-fold standardization (statistics from train only)
    # - Step 9: per-fold windowing
    # - Train model (training windows only)
    # - Evaluate model (test windows only)

    # 5. Save results of current fold
    # results[fold_id] = {"accuracy": acc, "f1": f1, ...}

# 6. Aggregate results across folds
# mean_acc = np.mean([r["accuracy"] for r in results.values()])
# std_acc = np.std([r["accuracy"] for r in results.values()])
# print(f"Mean accuracy: {mean_acc:.4f} ± {std_acc:.4f}")

# ========== Anti-leakage Checklist ==========
# ✓ Train/test separated by subject
# ✓ Statistics (mean/std) computed from training set only
# ✓ Feature scaling uses parameters from training set
# ✓ Windowing performed after splitting
# ✓ Hyperparameter tuning uses training data only
# ✓ Test set used strictly for final evaluation
'''

example_file = configs_dir / "loso_usage_example.py"
with open(example_file, 'w', encoding='utf-8') as f:
    f.write(example_code)

print(f"✓ Generated usage example: {example_file}")

print("\nHow to use:")
print("  1. export FOLD_ID=0  # set current fold")
print("  2. Run steps 6–9 (they will use the corresponding fold automatically)")
print("  3. Train the model and evaluate")
print("  4. Repeat steps 1–3 for all folds")
print("  5. Aggregate results (mean ± std)")

# ========== 6. Split visualization info ==========

print("\n" + "="*60)
print("6. Split visualization info")
print("="*60)

print(f"\nLOSO split matrix (first 5 folds):")
print(f"{'Fold':<6} {'TestSubject':<12} {'#TrainSubs':<12} {'#TestSamples':<12} {'#TrainSamples':<12}")
print("-" * 60)

for i in range(min(5, len(splits))):
    fold = splits[str(i)]
    stats = df_fold_stats[df_fold_stats['fold_id'] == i].iloc[0]
    print(f"{i:<6} {fold['test_subject']:<12} {fold['n_train']:<12} "
          f"{stats['n_test_samples']:<12} {stats['n_train_samples']:<12}")

if len(splits) > 5:
    print(f"... (total {len(splits)} folds)")

# ========== 7. Summary ==========

print("\n" + "="*60)
print("Step 10 complete - LOSO split")
print("="*60)

print(f"\nSplit strategy:")
print(f"  Method: LOSO (Leave-One-Subject-Out)")
print(f"  #folds: {n_subjects}")
print(f"  #subjects: {n_subjects}")
print(f"  Train per fold: {n_subjects-1} subjects")
print(f"  Test per fold: 1 subject")

print(f"\nData distribution:")
print(f"  Total samples: {len(df):,}")
print(f"  Train ratio (avg): {df_fold_stats['train_ratio'].mean()*100:.2f}%")
print(f"  Test ratio (avg): {df_fold_stats['test_ratio'].mean()*100:.2f}%")

print(f"\nValidation results:")
print(f"  ✓ No subject overlap")
print(f"  ✓ All subjects covered")
print(f"  ✓ Each subject tested exactly once")
print(f"  ✓ Train/test sets are disjoint")

print(f"\nOutput files:")
print(f"  Main config: {splits_file}")
print(f"  Detailed config: {loso_config_file}")
print(f"  Fold stats: {fold_stats_file}")
print(f"  Usage example: {example_file}")

print("\nAnti-leakage principles:")
print("  1. ✓ Fully separated by subject")
print("  2. ✓ Statistics computed from training fold only")
print("  3. ✓ Feature engineering is fold-internal")
print("  4. ✓ Window slicing performed after splitting")
print("  5. ✓ Hyperparameter tuning limited to training data")
print("  6. ✓ Test set used strictly for independent evaluation")
print("  7. ✓ Cross-fold aggregation uses independent metrics")

print("\nNext steps:")
print("  - Set export FOLD_ID=<fold_id>")
print("  - Re-run steps 6–9 (per-fold processing)")
print("  - Train and evaluate models")
print("  - Iterate all folds and aggregate results")

print("="*60)

Step 10: LOSO split

1. Load data and get subject list
Loading data: data/lara/mbientlab/proc/labeled/
Data shape: (556504, 16)
Total samples: 556,504

Subject list:
  Total: 8 subjects
  IDs: ['S07', 'S08', 'S09', 'S10', 'S11', 'S12', 'S13', 'S14']

Sample count per subject:
  S07:   76,522 samples (13.75%)
  S08:   64,857 samples (11.65%)
  S09:   77,701 samples (13.96%)
  S10:   82,659 samples (14.85%)
  S11:   70,410 samples (12.65%)
  S12:   30,923 samples ( 5.56%)
  S13:   82,335 samples (14.80%)
  S14:   71,097 samples (12.78%)

2. Generate LOSO split

LOSO strategy: Leave-One-Subject-Out
  #folds = #subjects = 8
  Per fold: 1 subject for test, 7 subjects for train
  Fold 0: test S07, train 7 subjects
  Fold 1: test S08, train 7 subjects
  Fold 2: test S09, train 7 subjects
  Fold 3: test S10, train 7 subjects
  Fold 4: test S11, train 7 subjects
  Fold 5: test S12, train 7 subjects
  Fold 6: test S13, train 7 subjects
  Fold 7: test S14, train 7 subjects

✓ Generated 8 LOSO fol

In [13]:
# =============================================
# Step 11: rTsfNet (IMWUT 2024 official architecture aligned · multi-head rotation
#                  parameters estimated via TSF-Mixer · block-wise multi-scale TSF ·
#                  axis/channel binary selection · label injection · rotation
#                  parameters accumulated across heads)
#                  — compatible with arbitrary window length T (automatic symmetric
#                  padding, Keras 3 / Graph safe, full shape inference)
#                  — Dataset: LARa MbientLab (LOSO over subjects)
# =============================================
import os, json, random, math, warnings
warnings.filterwarnings("ignore")

# !pip -q install "tensorflow==2.15.1"
import numpy as np
import pandas as pd
from pathlib import Path

import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.layers import (
    Dense, Dropout, LayerNormalization, LeakyReLU,
    Layer, Activation, TimeDistributed, Flatten, Concatenate
)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.regularizers import l2
from sklearn.metrics import f1_score, accuracy_score

# ---- Random seeds ----
SEED = 42
tf.random.set_seed(SEED); np.random.seed(SEED); random.seed(SEED)

print("\n\nStep 11: rTsfNet (IMWUT 2024) official architecture-aligned version "
      "— LARa MbientLab LOSO (supports arbitrary T, Keras 3 safe)")
print("=" * 76)

# ==================== Configurable hyperparameters ====================
FS = 50.0                 # Sampling rate (Hz)
IMU_ROT_HEADS = 2         # Number of multi-head 3D rotation blocks
MLP_BASE = 128            # Base width of MLP
MLP_DEPTH = 3             # Number of MLP layers
DROPOUT = 0.5
LR = 1e-3
WEIGHT_DECAY = 1e-6

BOOTSTRAP_EPOCHS = 150
TOTAL_EPOCHS = 350
BATCH_SIZE = 32
PATIENCE = 50
USE_ORIG_INPUT = True         # Whether to keep original input (+L2) stream
USE_BINARY_SELECTION = True   # Whether to enable axis/channel binary selection
LN_EPS = 1e-7                 # Recommended value in official README (TF2.15 LayerNorm bug)
PAD_MODE = 'SYMMETRIC'        # 'SYMMETRIC' / 'REFLECT' / 'CONSTANT'

# === Block set configuration ===
BLOCK_SPECS = [
    dict(name='short', num_blocks=4, use_time=True,  use_freq=False),  # primarily time-domain
    dict(name='long',  num_blocks=1, use_time=False, use_freq=True),   # primarily frequency-domain
]

# ==================== Paths and configuration (LARa version) ====================
BASE = Path('/content')           # Colab the default working directory
configs_dir = BASE / 'configs'
windows_root = BASE / 'data' / 'lara' / 'mbientlab' / 'proc' / 'windows'

models_root = BASE / 'models'
models_dir = models_root / 'rtsfnet_official_lara_step11'
models_dir.mkdir(parents=True, exist_ok=True)

# ---- Load label configuration (from Step 8: labels.json) ----
labels_json = configs_dir / 'labels.json'
with open(labels_json, 'r') as f:
    labels_cfg = json.load(f)

# unified_labels: {label_name: id}
unified_labels = labels_cfg['unified_labels']
id_to_name = {int(v): k for k, v in unified_labels.items()}

# Align the number of classes using “max ID + 1” (the transition class may have no samples, but its output dimension is still reserved)
n_classes = max(id_to_name.keys()) + 1

print(f"\nNumber of classes (max ID + 1): {n_classes}")
print("Class ID → name mapping (from labels.json):")
for cid in sorted(id_to_name.keys()):
    print(f"  {cid:2d}: {id_to_name[cid]}")

# ---- Load LOSO splits (from Step 10: splits.json) ----
splits_path = configs_dir / 'splits.json'
with open(splits_path, 'r') as f:
    splits_cfg = json.load(f)

fold_ids = sorted(int(k) for k in splits_cfg.keys())
print(f"\nDetected {len(fold_ids)} folds from {splits_path}: {fold_ids}")

# ==================== Data loading for LARa windows ====================
def load_fold_data(fold_k: int, windows_root: Path):
    """
    Load data for one LOSO fold of LARa:
      - Uses sliding-window outputs from Step 9:
        data/lara/mbientlab/proc/windows/fold_xx/X_train.npy, y_train.npy, X_test.npy, y_test.npy
      - Windows have 8 channels: [ax, ay, az, gx, gy, gz, acc_mag, gyr_mag]
      - For rTsfNet Only keep the first 6 channels (ACC+GYR), so that the model input remains [B, T, 6].
    """
    fold_tag = f"fold_{fold_k:02d}"
    fold_dir = windows_root / fold_tag

    X_train = np.load(fold_dir / 'X_train.npy')   # [N_train, T, 8]
    y_train = np.load(fold_dir / 'y_train.npy')   # [N_train]
    X_test  = np.load(fold_dir / 'X_test.npy')    # [N_test,  T, 8]
    y_test  = np.load(fold_dir / 'y_test.npy')    # [N_test]

    # Keep only the 6 raw channels: ax, ay, az, gx, gy, gz.
    if X_train.shape[-1] != 6:
        X_train = X_train[..., :6]
        X_test  = X_test[..., :6]

    # Cast to float32
    X_train = X_train.astype('float32')
    X_test  = X_test.astype('float32')

    return X_train, y_train, X_test, y_test

# ==================== Shared: Keras 3–safe MLP stack ====================
class MLPStack(Layer):
    """
    Dense -> LayerNorm -> LeakyReLU -> Dropout repeated 'depth' times,
    hidden width base_kn * (2**k), k: depth-1..0; output dimensionality
    is fixed to base_kn.
    """
    def __init__(self, base_kn=128, depth=3, drop=0.5, wd=0.0, ln_eps=1e-7, name=None):
        super().__init__(name=name)
        self.base_kn = int(base_kn)
        self.depth = int(depth)
        self.drop = float(drop)
        self.wd = float(wd)
        self.ln_eps = float(ln_eps)

        self.seq = []
        for k in range(self.depth - 1, -1, -1):
            self.seq.append(Dense(self.base_kn * (2**k), kernel_regularizer=l2(self.wd)))
            self.seq.append(LayerNormalization(epsilon=self.ln_eps))
            self.seq.append(LeakyReLU())
            self.seq.append(Dropout(self.drop))

    @property
    def out_dim(self):
        return self.base_kn

    def call(self, x, training=None):
        z = x
        for lyr in self.seq:
            if isinstance(lyr, Dropout):
                z = lyr(z, training=training)
            else:
                z = lyr(z)
        return z

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], self.out_dim])

# ==================== TSF extraction (axis-wise) ====================
TIME_FEATS = 12  # mean/std/max/min/ptp/rms/energy/skew/kurt/zcr/ar1/ar2
FREQ_FEATS = 7   # centroid/entropy/flatness/soft-peak/bandpowers(3)

class TSFFeatureLayer(Layer):
    """Compute axis-wise TSF features for a single block [B, L, C];
    output shape [B, C, F] where F is the fixed TSF feature dimensionality."""
    def __init__(self, fs=50.0, use_time=True, use_freq=True, **kwargs):
        super().__init__(**kwargs)
        self.fs = float(fs)
        self.use_time = bool(use_time)
        self.use_freq = bool(use_freq)
        self.eps = 1e-8
        self._feat_dim = (TIME_FEATS if self.use_time else 0) + (FREQ_FEATS if self.use_freq else 0)

    def get_config(self):
        cfg = super().get_config()
        cfg.update({'fs': self.fs, 'use_time': self.use_time, 'use_freq': self.use_freq})
        return cfg

    def call(self, x):  # x: [B, L, C]
        feats = []
        if self.use_time:
            mean = tf.reduce_mean(x, axis=1, keepdims=True)
            std  = tf.math.reduce_std(x, axis=1, keepdims=True) + self.eps
            maxv = tf.reduce_max(x, axis=1, keepdims=True)
            minv = tf.reduce_min(x, axis=1, keepdims=True)
            ptp  = maxv - minv
            rms  = tf.sqrt(tf.reduce_mean(tf.square(x), axis=1, keepdims=True))
            energy = tf.reduce_sum(tf.square(x), axis=1, keepdims=True)
            skew = tf.reduce_mean(tf.pow((x - mean) / std, 3), axis=1, keepdims=True)
            kurt = tf.reduce_mean(tf.pow((x - mean) / std, 4), axis=1, keepdims=True)
            signs = tf.sign(x)
            sign_changes = tf.abs(signs[:, 1:, :] - signs[:, :-1, :])
            zcr = tf.reduce_mean(sign_changes, axis=1, keepdims=True) / 2.0
            x_t1 = x[:, :-1, :]; x_tn1 = x[:, 1:, :]
            ar1 = tf.reduce_sum(x_t1 * x_tn1, axis=1, keepdims=True) / (
                tf.reduce_sum(tf.square(x_t1), axis=1, keepdims=True) + self.eps
            )
            x_t2 = x[:, :-2, :]; x_tn2 = x[:, 2:, :]
            ar2 = tf.reduce_sum(x_t2 * x_tn2, axis=1, keepdims=True) / (
                tf.reduce_sum(tf.square(x_t2), axis=1, keepdims=True) + self.eps
            )
            feats += [mean, std, maxv, minv, ptp, rms, energy, skew, kurt, zcr, ar1, ar2]

        if self.use_freq:
            mean = tf.reduce_mean(x, axis=1, keepdims=True)
            xc = x - mean
            x_bc_t = tf.transpose(xc, [0, 2, 1])               # [B, C, L]
            fft = tf.signal.rfft(x_bc_t)                      # [B, C, F]
            power = tf.square(tf.abs(fft)) + self.eps         # [B, C, F]
            power = tf.transpose(power, [0, 2, 1])            # [B, F, C]

            F = tf.shape(power)[1]
            freqs = tf.linspace(0.0, tf.cast(self.fs, tf.float32) / 2.0, F)  # [F]
            freqs = tf.reshape(freqs, [1, F, 1])                             # [1, F, 1]

            p = power / (tf.reduce_sum(power, axis=1, keepdims=True) + self.eps)
            centroid = tf.reduce_sum(p * freqs, axis=1, keepdims=True)       # [B, 1, C]
            entropy  = -tf.reduce_sum(p * tf.math.log(p + self.eps), axis=1, keepdims=True) / \
                        (tf.math.log(tf.cast(F, tf.float32) + self.eps))
            geo = tf.exp(tf.reduce_mean(tf.math.log(power), axis=1, keepdims=True))
            ari = tf.reduce_mean(power, axis=1, keepdims=True)
            flatness = geo / (ari + self.eps)
            temp = 10.0
            w = tf.nn.softmax(power * temp, axis=1)                          # [B, F, C]
            soft_peak = tf.reduce_sum(w * freqs, axis=1, keepdims=True)      # [B, 1, C]

            def band(low, high):
                mask = tf.cast((freqs >= low) & (freqs < high), tf.float32)
                bp = tf.reduce_sum(power * mask, axis=1, keepdims=True) / (
                    tf.reduce_sum(power, axis=1, keepdims=True) + self.eps
                )
                return bp

            bp1 = band(0.5, 3.0); bp2 = band(3.0, 8.0); bp3 = band(8.0, 15.0)
            feats += [centroid, entropy, flatness, soft_peak, bp1, bp2, bp3]

        res = tf.concat(feats, axis=1)                       # [B, Fnum, C]
        return tf.transpose(res, [0, 2, 1])                  # [B, C, Fnum]

    def compute_output_shape(self, input_shape):
        # input_shape: (B, L, C) -> (B, C, Fnum)
        return tf.TensorShape([input_shape[0], input_shape[2], self._feat_dim])

# ==================== Utility: L2-norm channels (wrapped as a Layer) ====================
class AddL2Channels(Layer):
    def call(self, x, training=None):
        acc = x[:, :, :3]; gyr = x[:, :, 3:6]
        l2_acc = tf.sqrt(tf.reduce_sum(tf.square(acc), axis=-1, keepdims=True))
        l2_gyr = tf.sqrt(tf.reduce_sum(tf.square(gyr), axis=-1, keepdims=True))
        return tf.concat([x, l2_acc, l2_gyr], axis=-1)  # [B, T, 8]

    def compute_output_shape(self, input_shape):
        return tf.TensorShape([input_shape[0], input_shape[1], 8])

# ==================== Framing into blocks (Graph/Keras 3 safe) ====================
def _int_ceil_div(a, b):
    a = tf.cast(a, tf.int32); b = tf.cast(b, tf.int32)
    return tf.math.floordiv(a + b - 1, b)

def frame_signal_with_padding(x, num_blocks, pad_mode='SYMMETRIC'):
    """
    [B, T, C] -> symmetric padding to length L * num_blocks
    and reshape to [B, num_blocks, L, C].
    """
    B = tf.shape(x)[0]; T = tf.shape(x)[1]; C = tf.shape(x)[2]
    nb = tf.cast(num_blocks, tf.int32)
    L  = _int_ceil_div(T, nb)
    total = L * nb
    pad_len = total - T
    pad_left  = tf.math.floordiv(pad_len, 2)
    pad_right = pad_len - pad_left
    paddings = tf.stack([
        tf.constant([0, 0], dtype=tf.int32),
        tf.stack([pad_left, pad_right]),
        tf.constant([0, 0], dtype=tf.int32)
    ], axis=0)  # [3, 2]
    x_pad = tf.pad(x, paddings, mode=pad_mode)             # even if pad_len==0, still goes through tf.pad
    x_blocks = tf.reshape(x_pad, [B, nb, L, C])
    return x_blocks

class BlockTSFExtractor(Layer):
    """
    Apply TSF extraction and axis-tag injection for a block set.
    Input:  x with shape [B, T, C_total]
    Output: TSF tensor [B, num_blocks, A, F_total] (A = C_total; last dimension includes tags).
    """
    def __init__(self, num_blocks, fs, use_time, use_freq,
                 tag_spec=None, pad_mode='SYMMETRIC', name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.num_blocks = int(num_blocks)
        self.tsf = TSFFeatureLayer(fs=fs, use_time=use_time, use_freq=use_freq)
        self.tag_spec = tag_spec  # dict: {'axis_tags': [A, tag_dim]}
        self.pad_mode = pad_mode
        self.tag_dim = 0 if (tag_spec is None or 'axis_tags' not in tag_spec) else int(tag_spec['axis_tags'].shape[1])
        self.base_feat_dim = (TIME_FEATS if use_time else 0) + (FREQ_FEATS if use_freq else 0)
        self.out_feat_dim = self.base_feat_dim + self.tag_dim

    def get_config(self):
        cfg = super().get_config()
        cfg.update({'num_blocks': self.num_blocks, 'fs': self.tsf.fs,
                    'use_time': self.tsf.use_time, 'use_freq': self.tsf.use_freq,
                    'pad_mode': self.pad_mode})
        return cfg

    def call(self, x, training=None):  # x: [B, T, C]
        xb = frame_signal_with_padding(x, self.num_blocks, pad_mode=self.pad_mode)  # [B, K, L, C]
        B = tf.shape(xb)[0]; K = tf.shape(xb)[1]; L = tf.shape(xb)[2]; C = tf.shape(xb)[3]
        xb2 = tf.reshape(xb, [B * K, L, C])                   # [B*K, L, C]
        tsf_axis = self.tsf(xb2)                             # [B*K, C, F]
        tsf_axis = tf.reshape(tsf_axis, [B, K, C, self.base_feat_dim])  # [B, K, A, F_base]

        # Label injection (concatenate axis tags along the last dimension)
        if self.tag_dim > 0:
            axis_tags = tf.convert_to_tensor(self.tag_spec['axis_tags'], dtype=tsf_axis.dtype)  # [A, tag_dim]
            axis_tags = tf.reshape(axis_tags, [1, 1, tf.shape(tsf_axis)[2], -1])  # [1, 1, A, tag_dim]
            axis_tags = tf.tile(axis_tags, [B, K, 1, 1])                          # [B, K, A, tag_dim]
            tsf_axis = tf.concat([tsf_axis, axis_tags], axis=-1)                  # [B, K, A, F_base+tag_dim]
        return tsf_axis  # [B, K, A, F_total]

    def compute_output_shape(self, input_shape):
        # input_shape: (B, T, C) -> (B, K, A(=C), F_total)
        return tf.TensorShape([input_shape[0], self.num_blocks, input_shape[2], self.out_feat_dim])

# ==================== Binary gate (straight-through estimator) ====================
class BinaryGate(Layer):
    def call(self, p, training=None):
        p = tf.clip_by_value(p, 0.0, 1.0)
        hard = tf.round(p)
        return hard + tf.stop_gradient(p - hard)

    def compute_output_shape(self, input_shape):
        return tf.TensorShape(input_shape)

# ==================== TSF-Mixer sub-Block & Block ====================
class TSFMixerSubBlock(Layer):
    """
    Input: per-block axis-level TSF features with shape [B', A, F]
    Architecture: axis-shared MLP -> concatenate axes -> MLP
    Output: block-level feature [B', H_out]
    """
    def __init__(self, axis_hidden=128, out_hidden=128, base_depth=2,
                 drop=0.5, wd=0.0, ln_eps=1e-7, name=None):
        super().__init__(name=name)
        self.axis_hidden = int(axis_hidden)
        self.out_hidden = int(out_hidden)
        self.base_depth = int(base_depth)
        self.drop = float(drop); self.wd = float(wd); self.ln_eps = float(ln_eps)
        # Axis-wise shared MLP
        self.axis_mlp_layers = []
        for k in range(self.base_depth - 1, -1, -1):
            self.axis_mlp_layers.append(Dense(self.axis_hidden * (2**k), kernel_regularizer=l2(self.wd)))
            self.axis_mlp_layers.append(LayerNormalization(epsilon=self.ln_eps))
            self.axis_mlp_layers.append(LeakyReLU())
            self.axis_mlp_layers.append(Dropout(self.drop))
        # Output MLP (actual layer stack)
        self.out_stack = MLPStack(base_kn=self.out_hidden, depth=self.base_depth,
                                  drop=self.drop, wd=self.wd, ln_eps=self.ln_eps, name=f'{self.name}_out')

    def call(self, x, training=None, **kwargs):  # x: [B', A, F]
        Bp = tf.shape(x)[0]; A = tf.shape(x)[1]; F = tf.shape(x)[2]
        x2 = tf.reshape(x, [Bp * A, F])
        z = x2
        for lyr in self.axis_mlp_layers:
            if isinstance(lyr, Dropout):
                z = lyr(z, training=training)
            else:
                z = lyr(z)
        # Key fix: axis representation dimension is fixed to self.axis_hidden to avoid None shapes
        z = tf.reshape(z, [Bp, A, self.axis_hidden])   # [B', A, H_axis]
        z = tf.reshape(z, [Bp, A * self.axis_hidden])  # concatenate all axes [B', A*H_axis]
        z = self.out_stack(z, training=training)       # [B', H_out]
        return z

    def compute_output_shape(self, input_shape):
        # input_shape: (B', A, F) -> (B', out_hidden)
        return tf.TensorShape([input_shape[0], self.out_stack.out_dim])

class TSFMixerBlock(Layer):
    """
    Extends the sub-Block with:
      - channel-wise binary selection (over feature dimension F)
      - axis-wise binary selection (over axis dimension A)
    Output shape: [B', out_hidden]
    """
    def __init__(self, feat_dim, axis_hidden=128, out_hidden=128, base_depth=2,
                 drop=0.5, wd=0.0, ln_eps=1e-7, use_binary=True, name=None):
        super().__init__(name=name)
        self.use_binary = bool(use_binary)
        self.sub = TSFMixerSubBlock(axis_hidden, out_hidden, base_depth, drop, wd, ln_eps,
                                    name=f'{name}_sub')
        self.axis_gate_dense = Dense(1, activation='sigmoid', name=f'{name}_axis_gate')
        self.chan_gate_dense = Dense(int(feat_dim), activation='sigmoid', name=f'{name}_chan_gate')
        self.bin_gate = BinaryGate(name=f'{name}_bin')
        self.out_stack = MLPStack(base_kn=out_hidden, depth=base_depth,
                                  drop=drop, wd=wd, ln_eps=ln_eps, name=f'{name}_out')

    def call(self, x, training=None, **kwargs):  # x: [B', A, F]
        Bp = tf.shape(x)[0]; A = tf.shape(x)[1]; F = tf.shape(x)[2]

        # ---- Channel-wise gate (shared across axes)
        x_mean_axis = tf.reduce_mean(x, axis=1)            # [B', F]
        p_chan = self.chan_gate_dense(x_mean_axis)         # [B', F]
        p_chan = tf.reshape(p_chan, [Bp, 1, F])            # broadcast to all axes
        g_chan = self.bin_gate(p_chan, training=training) if self.use_binary else p_chan
        x = x * g_chan                                     # zero-out pruned channels

        # ---- Axis-shared MLP to obtain axis representations (shared weights)
        #      and estimate axis gate
        x2 = tf.reshape(x, [Bp * A, F])
        z = x2
        for lyr in self.sub.axis_mlp_layers:
            if isinstance(lyr, Dropout):
                z = lyr(z, training=training)
            else:
                z = lyr(z)
        # Key fix: axis representation dimension is fixed to self.sub.axis_hidden to avoid None shapes
        z = tf.reshape(z, [Bp, A, self.sub.axis_hidden])   # [B', A, H_axis]

        p_axis = self.axis_gate_dense(z)                   # [B', A, 1]
        g_axis = self.bin_gate(p_axis, training=training) if self.use_binary else p_axis
        z = z * g_axis                                     # zero-out pruned axes

        # ---- Concatenate axes -> output MLP
        z = tf.reshape(z, [Bp, A * self.sub.axis_hidden])  # [B', A*H_axis]
        z = self.out_stack(z, training=training)           # [B', H_out]
        return z

    def compute_output_shape(self, input_shape):
        # input_shape: (B', A, F) -> (B', out_hidden)
        return tf.TensorShape([input_shape[0], self.out_stack.out_dim])

# ==================== Rotation-parameter estimation block ====================
def _feat_dim_for_spec(use_time, use_freq, tag_dim):
    base = (TIME_FEATS if use_time else 0) + (FREQ_FEATS if use_freq else 0)
    return base + tag_dim

class RotationParamEstimator(Layer):
    """
    Input: [B, T, 6] raw IMU (ACC + GYR); internally append L2 channels,
    extract TSF (for multiple block sets), pass through TSF-Mixer,
    concatenate all block-level features, then MLP -> Dense(4, tanh)
    to produce Rodrigues 4-parameter representation.
    """
    def __init__(self, block_specs, fs, mlp_base=128, mlp_depth=2,
                 drop=0.5, wd=0.0, ln_eps=1e-7,
                 use_binary=True, pad_mode='SYMMETRIC', name=None):
        super().__init__(name=name)
        self.block_specs = block_specs
        self.fs = fs
        self.mlp_base = int(mlp_base)
        self.mlp_depth = int(mlp_depth)
        self.drop = float(drop)
        self.wd = float(wd)
        self.ln_eps = float(ln_eps)
        self.use_binary = bool(use_binary)
        self.pad_mode = pad_mode

        # Rotation stage: input channels fixed to 8 -> axis tags (axis_type, sensor_type)
        axis_tags = []
        for i in range(8):
            axis_type = i + 1
            sensor_type = 1 if (i <= 2 or i == 6) else 2  # acc & l2_acc=1; gyr & l2_gyr=2
            axis_tags.append([axis_type, sensor_type])
        axis_tags = np.array(axis_tags, dtype=np.float32)
        self.tag_spec = {'axis_tags': axis_tags}
        tag_dim = axis_tags.shape[1]

        # For each block set, build subgraph: Extractor -> TimeDistributed(Mixer) -> Flatten
        self.extractors = []
        self.td_mixers  = []
        self.flatteners = []
        for spec in block_specs:
            ext = BlockTSFExtractor(num_blocks=spec['num_blocks'], fs=fs,
                                    use_time=spec['use_time'], use_freq=spec['use_freq'],
                                    tag_spec=self.tag_spec, pad_mode=self.pad_mode,
                                    name=f'rot_ext_{spec["name"]}')
            self.extractors.append(ext)
            feat_dim = _feat_dim_for_spec(spec['use_time'], spec['use_freq'], tag_dim)
            mix = TSFMixerBlock(feat_dim=feat_dim, axis_hidden=self.mlp_base,
                                out_hidden=self.mlp_base,
                                base_depth=max(1, self.mlp_depth - 1),
                                drop=self.drop, wd=self.wd,
                                ln_eps=self.ln_eps, use_binary=self.use_binary,
                                name=f'rot_mix_{spec["name"]}')
            self.td_mixers.append(TimeDistributed(mix, name=f'rot_td_{spec["name"]}'))
            self.flatteners.append(Flatten(name=f'rot_flat_{spec["name"]}'))

        self.concat_sets = Concatenate(name='rot_concat_sets')
        self.post_stack = MLPStack(base_kn=self.mlp_base, depth=self.mlp_depth,
                                   drop=self.drop, wd=self.wd, ln_eps=self.ln_eps, name='rot_post')
        self.out_head = Dense(4, activation='tanh', name='rot4_tanh')
        self.add_l2 = AddL2Channels()

    def call(self, x, training=None, **kwargs):  # x: [B, T, 6]
        x8 = self.add_l2(x)  # [B, T, 8]
        feats_all = []
        for ext, td, flt in zip(self.extractors, self.td_mixers, self.flatteners):
            tsf_blocks = ext(x8, training=training)        # [B, K, A, F]
            blk_feat   = td(tsf_blocks, training=training) # [B, K, H]
            blk_feat   = flt(blk_feat)                     # [B, K*H]
            feats_all.append(blk_feat)
        h = self.concat_sets(feats_all)                    # concatenate all block sets [B, sum(K*H)]
        h = self.post_stack(h, training=training)
        rot4 = self.out_head(h)                            # tanh -> (-1, 1)
        return rot4  # [B, 4] -> [axis(xyz), angle]

    def compute_output_shape(self, input_shape):
        # Input (B, T, 6) -> output (B, 4)
        return tf.TensorShape([input_shape[0], 4])

# ==================== Multi-head 3D rotation (official: parameters accumulated across heads) ====================
class Multihead3DRotationOfficial(Layer):
    """
    Input [B, T, 6] (ACC + GYR); output: list whose each element is a
    rotated stream [B, T, 6]. Rotation parameters are estimated by
    RotationParamEstimator; for head index >= 2, the 4 parameters are
    accumulated over previous heads.
    """
    def __init__(self, head_nums=2, fs=50.0, mlp_base=128, mlp_depth=2,
                 drop=0.5, wd=0.0, ln_eps=1e-7,
                 block_specs=None, use_binary=True, pad_mode='SYMMETRIC', name=None):
        super().__init__(name=name)
        if block_specs is None:
            block_specs = BLOCK_SPECS
        self.head_nums = int(head_nums)
        self.estimator = RotationParamEstimator(block_specs=block_specs, fs=fs,
                                                mlp_base=mlp_base, mlp_depth=mlp_depth,
                                                drop=drop, wd=wd,
                                                ln_eps=ln_eps, use_binary=use_binary,
                                                pad_mode=pad_mode,
                                                name='rot_estimator')
        self.eps = 1e-8

    def compute_output_shape(self, input_shape):
        return [tf.TensorShape(input_shape) for _ in range(self.head_nums)]

    def call(self, x, training=None, **kwargs):  # x: [B, T, 6]
        acc, gyr = x[:, :, :3], x[:, :, 3:6]
        out_list = []
        prev_rot4 = None
        for _ in range(self.head_nums):
            rot4 = self.estimator(x, training=training)     # [B, 4]
            if prev_rot4 is not None:
                rot4 = rot4 + prev_rot4                    # parameter accumulation
            prev_rot4 = rot4
            axis = rot4[:, :3]; angle = tf.expand_dims(rot4[:, 3], -1)
            R = self._axis_angle_to_R(axis, angle)          # [B, 3, 3]

            acc_t = tf.transpose(acc, [0, 2, 1])            # [B, 3, T]
            acc_rot = tf.transpose(tf.matmul(R, acc_t), [0, 2, 1])
            gyr_t = tf.transpose(gyr, [0, 2, 1])
            gyr_rot = tf.transpose(tf.matmul(R, gyr_t), [0, 2, 1])

            out_list.append(tf.concat([acc_rot, gyr_rot], axis=-1))  # [B, T, 6]
        return out_list

    def _axis_angle_to_R(self, axis_raw, angle_raw):
        axis = axis_raw / (tf.norm(axis_raw, axis=-1, keepdims=True) + self.eps)  # [B, 3]
        theta = angle_raw * math.pi                                               # [B, 1]
        B = tf.shape(axis)[0]
        ux, uy, uz = axis[:, 0], axis[:, 1], axis[:, 2]
        z = tf.zeros_like(ux)
        K = tf.stack([z, -uz,  uy,
                      uz,  z, -ux,
                     -uy,  ux,  z], axis=-1)
        K = tf.reshape(K, [B, 3, 3])
        I  = tf.tile(tf.eye(3, dtype=axis.dtype)[None, ...], [B, 1, 1])
        u = tf.expand_dims(axis, -1)                                              # [B, 3, 1]
        uuT = tf.matmul(u, u, transpose_b=True)                                   # [B, 3, 3]
        cos = tf.reshape(tf.cos(theta), [-1, 1, 1])
        sin = tf.reshape(tf.sin(theta), [-1, 1, 1])
        R = cos * I + (1.0 - cos) * uuT + sin * K                                 # [B, 3, 3]
        return R

# ==================== Main rTsfNet body ====================
class AddL2ChannelsPublic(Layer):
    def call(self, x, training=None):
        acc = x[:, :, :3]; gyr = x[:, :, 3:6]
        l2_acc = tf.sqrt(tf.reduce_sum(tf.square(acc), axis=-1, keepdims=True))
        l2_gyr = tf.sqrt(tf.reduce_sum(tf.square(gyr), axis=-1, keepdims=True))
        return tf.concat([x, l2_acc, l2_gyr], axis=-1)  # [B, T, 8]

def r_tsf_net_official(x_shape, n_classes,
                       learning_rate=1e-3, base_kn=128, depth=3, dropout_rate=0.5,
                       imu_rot_heads=2, fs=50.0, use_orig_input=True,
                       use_binary_selection=True, ln_eps=1e-7, pad_mode='SYMMETRIC'):

    inputs = Input(shape=x_shape[1:])     # [T, 6]
    x = inputs

    # ---- Multi-head 3D rotation (parameters via TSF-Mixer; accumulated across heads;
    #      internal automatic block-wise padding)
    rot_layer = Multihead3DRotationOfficial(
        head_nums=imu_rot_heads, fs=fs,
        mlp_base=base_kn, mlp_depth=max(1, depth - 1), drop=dropout_rate, wd=WEIGHT_DECAY,
        ln_eps=ln_eps, block_specs=BLOCK_SPECS, use_binary=use_binary_selection, pad_mode=pad_mode,
        name='multihead_rot_official'
    )
    rotated_list = rot_layer(x)   # list of [B, T, 6]

    # ---- Optionally keep original input (+L2) and concatenate all streams
    streams = []
    add_l2 = AddL2ChannelsPublic()
    if use_orig_input:
        streams.append(add_l2(x))  # [B, T, 8]
    for xr in rotated_list:
        streams.append(add_l2(xr))
    concat_streams = Concatenate(axis=-1, name='concat_streams')(streams)  # [B, T, 8*(1+heads)]

    # ---- Backbone: multiple block sets -> TSF -> axis tags -> TimeDistributed(TSF-Mixer)
    #      -> Flatten -> concatenate -> MLP -> classification
    feats_all_sets = []
    # Build axis tags (axis_type & sensor_type), replicated for each stream of 8 channels
    num_streams = (1 if use_orig_input else 0) + imu_rot_heads
    axis_tags_one_stream = []
    for i in range(8):
        axis_type = i + 1
        sensor_type = 1 if (i <= 2 or i == 6) else 2
        axis_tags_one_stream.append([axis_type, sensor_type])
    axis_tags_one_stream = np.array(axis_tags_one_stream, dtype=np.float32)
    axis_tags_all = np.concatenate([axis_tags_one_stream for _ in range(num_streams)], axis=0)  # [8*num_streams, 2]
    tag_spec_main = {'axis_tags': axis_tags_all}
    tag_dim_main = axis_tags_all.shape[1]

    for spec in BLOCK_SPECS:
        ext = BlockTSFExtractor(num_blocks=spec['num_blocks'], fs=fs,
                                use_time=spec['use_time'], use_freq=spec['use_freq'],
                                tag_spec=tag_spec_main, pad_mode=pad_mode,
                                name=f'main_ext_{spec["name"]}')
        feat_dim = _feat_dim_for_spec(spec['use_time'], spec['use_freq'], tag_dim_main)
        mix = TSFMixerBlock(feat_dim=feat_dim, axis_hidden=base_kn, out_hidden=base_kn,
                            base_depth=max(1, depth - 1), drop=dropout_rate, wd=WEIGHT_DECAY,
                            ln_eps=ln_eps, use_binary=use_binary_selection,
                            name=f'main_mix_{spec["name"]}')
        td  = TimeDistributed(mix, name=f'main_td_{spec["name"]}')
        flt = Flatten(name=f'main_flat_{spec["name"]}')

        tsf_blocks = ext(concat_streams)   # [B, K, A_all, F]
        blk_feat   = td(tsf_blocks)        # [B, K, H]
        blk_feat   = flt(blk_feat)         # [B, K*H]
        feats_all_sets.append(blk_feat)

    z = Concatenate(name='main_concat_sets')(feats_all_sets)  # concatenate all block sets [B, sum(K*H)]
    cls_stack = MLPStack(base_kn=base_kn, depth=depth, drop=dropout_rate,
                         wd=WEIGHT_DECAY, ln_eps=ln_eps, name='cls')
    z = cls_stack(z)
    logits = Dense(n_classes, kernel_regularizer=l2(WEIGHT_DECAY), name='logits')(z)
    probs  = Activation('softmax', dtype='float32', name='softmax')(logits)

    model = Model(inputs, probs, name='rTsfNet_official_aligned')

    opt = Adam(learning_rate=learning_rate, amsgrad=True)
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=opt,
        metrics=['accuracy']
    )
    return model

def _history_to_df(hist_obj):
    """Convert a Keras History object to a DataFrame and normalize
    acc/val_acc keys to accuracy/val_accuracy."""
    d = dict(hist_obj.history)
    if 'acc' in d and 'accuracy' not in d:
        d['accuracy'] = d.pop('acc')
    if 'val_acc' in d and 'val_accuracy' not in d:
        d['val_accuracy'] = d.pop('val_acc')
    return pd.DataFrame(d)

# ==================== Loop over all folds (LARa LOSO) ====================
all_results = []

for FOLD_TO_TRAIN in fold_ids:
    test_subject = splits_cfg[str(FOLD_TO_TRAIN)]['test_subject']
    print(f"\nTraining Fold {FOLD_TO_TRAIN} "
          f"(test subject: {test_subject})")
    print(f"Bootstrap epochs: {BOOTSTRAP_EPOCHS}, total epochs: {TOTAL_EPOCHS}, patience: {PATIENCE}")
    print("=" * 76)

    X_train, y_train, X_test, y_test = load_fold_data(FOLD_TO_TRAIN, windows_root)
    print(f"Train set: {X_train.shape}, test set: {X_test.shape}")

    model = r_tsf_net_official(
        x_shape=X_train.shape,
        n_classes=n_classes,
        learning_rate=LR,
        base_kn=MLP_BASE,
        depth=MLP_DEPTH,
        dropout_rate=DROPOUT,
        imu_rot_heads=IMU_ROT_HEADS,
        fs=FS,
        use_orig_input=USE_ORIG_INPUT,
        use_binary_selection=USE_BINARY_SELECTION,
        ln_eps=LN_EPS,
        pad_mode=PAD_MODE
    )

    print(f"\nTotal number of model parameters: {model.count_params():,}")
    model.summary(line_length=140)

    print(f"\nPhase 1: bootstrap training ({BOOTSTRAP_EPOCHS} epochs)...")
    history1 = model.fit(
        X_train, y_train,
        batch_size=BATCH_SIZE,
        epochs=BOOTSTRAP_EPOCHS,
        validation_data=(X_test, y_test),
        verbose=1
    )

    print(f"\nPhase 2: full training (additional {TOTAL_EPOCHS - BOOTSTRAP_EPOCHS} epochs)...")
    early_stop = EarlyStopping(monitor='val_accuracy', patience=PATIENCE,
                               restore_best_weights=True, verbose=1)
    reduce_lr  = ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                   patience=20, min_lr=1e-6, verbose=1)

    history2 = model.fit(
        X_train, y_train,
        batch_size=BATCH_SIZE,
        epochs=TOTAL_EPOCHS - BOOTSTRAP_EPOCHS,
        validation_data=(X_test, y_test),
        callbacks=[early_stop, reduce_lr],
        verbose=1
    )

    # ---- Evaluation ----
    y_pred = model.predict(X_test, verbose=0)
    y_hat  = np.argmax(y_pred, axis=1)

    test_acc         = accuracy_score(y_test, y_hat)
    test_f1_macro    = f1_score(y_test, y_hat, average='macro')
    test_f1_weighted = f1_score(y_test, y_hat, average='weighted')

    print("\n" + "=" * 76)
    print(f"Final evaluation for Fold {FOLD_TO_TRAIN}:")
    print(f"  Accuracy: {test_acc * 100:.2f}%")
    print(f"  LOSO Macro-F1: {test_f1_macro:.4f}")
    print(f"  Weighted F1: {test_f1_weighted:.4f}")
    print("=" * 76)

    # ---- Save weights & history ----
    model_path = models_dir / f'model_fold{FOLD_TO_TRAIN}.weights.h5'
    model.save_weights(model_path)
    print(f"\n✓ Model weights saved: {model_path}")

    h1_df = _history_to_df(history1)
    h1_df['epoch'] = np.arange(1, len(h1_df) + 1)
    h1_df['phase'] = 'bootstrap'
    h2_df = _history_to_df(history2)
    h2_df['epoch'] = np.arange(len(h1_df) + 1, len(h1_df) + len(h2_df) + 1)
    h2_df['phase'] = 'stage2'
    hist_df = pd.concat([h1_df, h2_df], ignore_index=True, sort=True)
    front_cols = [c for c in ['epoch', 'phase'] if c in hist_df.columns]
    hist_df = hist_df[front_cols + [c for c in hist_df.columns if c not in front_cols]]

    hist_csv = models_dir / f'history_fold{FOLD_TO_TRAIN}.csv'
    hist_df.to_csv(hist_csv, index=False)
    print(f"✓ Training history saved: {hist_csv}")

    # ---- Save per-fold metrics JSON ----
    results = {
        'fold': FOLD_TO_TRAIN,
        'test_subject': test_subject,
        'accuracy': float(test_acc),
        'macro_f1': float(test_f1_macro),
        'weighted_f1': float(test_f1_weighted),
        'history_rows': int(len(hist_df)),
        'config': {
            'fs': FS, 'imu_rot_heads': IMU_ROT_HEADS, 'mlp_base': MLP_BASE,
            'mlp_depth': MLP_DEPTH, 'dropout': DROPOUT, 'lr': LR, 'weight_decay': WEIGHT_DECAY,
            'use_orig_input': USE_ORIG_INPUT, 'epochs': TOTAL_EPOCHS, 'bootstrap': BOOTSTRAP_EPOCHS,
            'patience': PATIENCE, 'batch_size': BATCH_SIZE,
            'use_binary_selection': USE_BINARY_SELECTION, 'block_specs': BLOCK_SPECS,
            'pad_mode': PAD_MODE
        }
    }
    with open(models_dir / f'fold{FOLD_TO_TRAIN}_results.json', 'w') as f:
        json.dump(results, f, indent=2)
    print(f"✓ Evaluation results saved: {models_dir / f'fold{FOLD_TO_TRAIN}_results.json'}")

    all_results.append({
        'fold': FOLD_TO_TRAIN,
        'test_subject': test_subject,
        'accuracy': float(test_acc),
        'macro_f1': float(test_f1_macro),
        'weighted_f1': float(test_f1_weighted)
    })

    tf.keras.backend.clear_session()

# ==================== Cross-fold summary ====================
print("\n" + "=" * 76)
print("All LARa folds trained. Summary of results:")
print("=" * 76)
summary_df = pd.DataFrame(all_results)
print(summary_df)
print(f"\nMean accuracy: {summary_df['accuracy'].mean() * 100:.2f}%")
print(f"Mean Macro-F1: {summary_df['macro_f1'].mean():.4f}")
print(f"Mean Weighted-F1: {summary_df['weighted_f1'].mean():.4f}")

summary_csv = models_dir / 'all_folds_summary.csv'
summary_df.to_csv(summary_csv, index=False)
print(f"\n✓ Summary results saved: {summary_csv}")

print("\n" + "=" * 76 +
      "\nStep 11 finished (IMWUT 2024 official rTsfNet · LARa MbientLab LOSO · "
      "adaptive block-wise TSF · Keras 3 safe · full shape inference)\n" +
      "=" * 76)



Step 11: rTsfNet (IMWUT 2024) official architecture-aligned version — LARa MbientLab LOSO (supports arbitrary T, Keras 3 safe)

Number of classes (max ID + 1): 13
Class ID → name mapping (from labels.json):
   0: transition
   1: walking
   2: running
   3: sitting
   4: standing
   5: upstairs
   6: downstairs
   7: lying
   8: cycling
   9: car
  10: bus
  11: train
  12: subway

Detected 8 folds from /content/configs/splits.json: [0, 1, 2, 3, 4, 5, 6, 7]

Training Fold 0 (test subject: S07)
Bootstrap epochs: 150, total epochs: 350, patience: 50
Train set: (4965, 150, 6), test set: (766, 150, 6)

Total number of model parameters: 961,343



Phase 1: bootstrap training (150 epochs)...
Epoch 1/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m170s[0m 519ms/step - accuracy: 0.3819 - loss: 1.8620 - val_accuracy: 0.4922 - val_loss: 1.4891
Epoch 2/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5306 - loss: 1.3779 - val_accuracy: 0.5901 - val_loss: 1.2400
Epoch 3/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5908 - loss: 1.2340 - val_accuracy: 0.5992 - val_loss: 1.2233
Epoch 4/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6072 - loss: 1.2025 - val_accuracy: 0.6031 - val_loss: 1.1847
Epoch 5/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6081 - loss: 1.1758 - val_accuracy: 0.5966 - val_loss: 1.1918
Epoch 6/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6105 - loss: 1.1389 - val_accuracy:


Phase 1: bootstrap training (150 epochs)...
Epoch 1/150
[1m159/159[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m176s[0m 557ms/step - accuracy: 0.3482 - loss: 1.9563 - val_accuracy: 0.5554 - val_loss: 1.2878
Epoch 2/150
[1m159/159[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5151 - loss: 1.4164 - val_accuracy: 0.5842 - val_loss: 1.2470
Epoch 3/150
[1m159/159[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5708 - loss: 1.2975 - val_accuracy: 0.5857 - val_loss: 1.2103
Epoch 4/150
[1m159/159[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step - accuracy: 0.5866 - loss: 1.2185 - val_accuracy: 0.5873 - val_loss: 1.1957
Epoch 5/150
[1m159/159[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5960 - loss: 1.1982 - val_accuracy: 0.5797 - val_loss: 1.1632
Epoch 6/150
[1m159/159[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5965 - loss: 1.1671 - val_accuracy:


Phase 1: bootstrap training (150 epochs)...
Epoch 1/150
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m109s[0m 151ms/step - accuracy: 0.3427 - loss: 1.9308 - val_accuracy: 0.6498 - val_loss: 1.0788
Epoch 2/150
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5360 - loss: 1.3538 - val_accuracy: 0.6667 - val_loss: 1.0760
Epoch 3/150
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5688 - loss: 1.2595 - val_accuracy: 0.6706 - val_loss: 1.0490
Epoch 4/150
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5771 - loss: 1.2312 - val_accuracy: 0.6744 - val_loss: 1.0154
Epoch 5/150
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5876 - loss: 1.1968 - val_accuracy: 0.6783 - val_loss: 0.9856
Epoch 6/150
[1m155/155[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5957 - loss: 1.1401 - val_accuracy:


Phase 1: bootstrap training (150 epochs)...
Epoch 1/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m172s[0m 550ms/step - accuracy: 0.3862 - loss: 1.8833 - val_accuracy: 0.5487 - val_loss: 1.3463
Epoch 2/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5549 - loss: 1.3265 - val_accuracy: 0.5601 - val_loss: 1.3028
Epoch 3/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5853 - loss: 1.2626 - val_accuracy: 0.5636 - val_loss: 1.2741
Epoch 4/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step - accuracy: 0.5980 - loss: 1.2115 - val_accuracy: 0.5659 - val_loss: 1.2365
Epoch 5/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step - accuracy: 0.6074 - loss: 1.1763 - val_accuracy: 0.5785 - val_loss: 1.2096
Epoch 6/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6178 - loss: 1.1581 - val_accuracy:


Phase 1: bootstrap training (150 epochs)...
Epoch 1/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m172s[0m 538ms/step - accuracy: 0.3913 - loss: 1.8768 - val_accuracy: 0.5818 - val_loss: 1.1898
Epoch 2/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5347 - loss: 1.3878 - val_accuracy: 0.6193 - val_loss: 1.1118
Epoch 3/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5776 - loss: 1.2798 - val_accuracy: 0.6206 - val_loss: 1.0975
Epoch 4/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5840 - loss: 1.2415 - val_accuracy: 0.6166 - val_loss: 1.0713
Epoch 5/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5891 - loss: 1.2188 - val_accuracy: 0.6153 - val_loss: 1.0447
Epoch 6/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5948 - loss: 1.1982 - val_accuracy:


Phase 1: bootstrap training (150 epochs)...
Epoch 1/150
[1m171/171[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m161s[0m 435ms/step - accuracy: 0.3576 - loss: 1.9389 - val_accuracy: 0.3668 - val_loss: 1.6219
Epoch 2/150
[1m171/171[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.4673 - loss: 1.5069 - val_accuracy: 0.4464 - val_loss: 1.5367
Epoch 3/150
[1m171/171[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5766 - loss: 1.2810 - val_accuracy: 0.4602 - val_loss: 1.5526
Epoch 4/150
[1m171/171[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5937 - loss: 1.1960 - val_accuracy: 0.4360 - val_loss: 1.5503
Epoch 5/150
[1m171/171[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5998 - loss: 1.1838 - val_accuracy: 0.4464 - val_loss: 1.5744
Epoch 6/150
[1m171/171[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6020 - loss: 1.1538 - val_accuracy:


Phase 1: bootstrap training (150 epochs)...
Epoch 1/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m169s[0m 544ms/step - accuracy: 0.3638 - loss: 1.9440 - val_accuracy: 0.6629 - val_loss: 1.0595
Epoch 2/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5622 - loss: 1.3280 - val_accuracy: 0.6708 - val_loss: 0.9953
Epoch 3/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5772 - loss: 1.2569 - val_accuracy: 0.6720 - val_loss: 0.9778
Epoch 4/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5907 - loss: 1.2063 - val_accuracy: 0.6743 - val_loss: 0.9604
Epoch 5/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5949 - loss: 1.1920 - val_accuracy: 0.7073 - val_loss: 0.9278
Epoch 6/150
[1m152/152[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5963 - loss: 1.1632 - val_accuracy:


Phase 1: bootstrap training (150 epochs)...
Epoch 1/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m165s[0m 504ms/step - accuracy: 0.3869 - loss: 1.8001 - val_accuracy: 0.4833 - val_loss: 1.5490
Epoch 2/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.5819 - loss: 1.2733 - val_accuracy: 0.5087 - val_loss: 1.6131
Epoch 3/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6034 - loss: 1.1927 - val_accuracy: 0.5167 - val_loss: 1.5673
Epoch 4/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6134 - loss: 1.1496 - val_accuracy: 0.5220 - val_loss: 1.4673
Epoch 5/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6155 - loss: 1.1191 - val_accuracy: 0.5487 - val_loss: 1.3857
Epoch 6/150
[1m156/156[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.6243 - loss: 1.1034 - val_accuracy: